优秀的编程知识分享平台

网站首页 > 技术文章 正文

计算卷积神经网络参数总数和输出形状

nanyue 2024-11-23 20:18:41 技术文章 2 ℃

在本文中,我们将讨论卷积层中的两个重要概念。

  • 如何计算参数的数量?
  • 产出的形状是如何计算的?

术语

input_shape

Input_shape = (batch_size, height, width, depth)
Batch_size =一次向前/向后传递的训练数据数

output_shape

Output_shape = (batch_size, height, width, depth)

过滤器/核

在卷积神经网络中,输入数据与卷积核进行卷积,卷积核用于提取特征。卷积核是一个矩阵,它将移动到图像像素数据(输入)上,并将执行与输入数据的特定区域的点积,输出将是点积的矩阵。

计算卷积层中输出的参数个数和形状

示例1

输入:

filter= 1
kernel_size = (3)
input_shape =(10、10、1)

让我们计算Conv2D中的参数数量和输出形状。

如何计算卷积层中的参数个数?

权重:(3,3)= 3*3 =9的卷积核

偏置: 1
[每个卷积核将添加一个偏置。由于只使用了一个卷积核,偏置=1*1]

一个大小为(3,3)的滤波器核的总参数= 9+1 =10

如何计算输出形状?

S→stride, p→padding, n→input size, f→filter size

默认Stride =1,没有提到填充(所以,p=0)

输出形状= n-f+1 = 10-3 +1 =8

在使用卷积滤波器对输入图像应用卷积后,输出将是一个特征映射。特征图中的通道数量取决于所使用卷积核的数量。在这个示例中,只使用了一个卷积核。因此,特征图中的通道数为1。

因此,特征图的Output_shape = (8,8,1)

model1=keras.models.Sequential()
model1.add(Conv2D(filters=1,kernel_size=3,input_shape=(10,10,1),activation='relu'))
model1.summary()

示例2

输入:

  • filters = 5
  • kernel_size=(3,3)
  • input_shape=(10,10,1)

如何计算卷积层中的参数个数?

权重:(3,3)= 3 * 3 =9的卷积核

偏置:1

总参数= 9+1 =10

过滤器的总数= 5

卷积核的总参数= 10 * 5=50

如何计算输出形状?

n = 10, f = 3 s = 1, p = 0

默认Stride =1,没有提到填充(所以,p=0)

输出形状= n-f+1 = 10-3 +1 =8

在使用卷积滤波器对输入图像应用卷积后,输出将是一个特征映射。特征图中的通道数量取决于所使用的过滤器的数量。在这个例子中,使用了5个过滤器。因此,特征图中的通道数为5。

因此,特征图的Output_shape = (8,8,5)

model2=keras.models.Sequential()
model2.add(Conv2D(filters=5,kernel_size=3,input_shape=(10,10,1),activation='relu'))
model2.summary()

示例3

输入:

  • filter = 5
  • kernel_size = (3)
  • input_shape = (10 10 3)

如何计算卷积层中的参数个数?

权重:(3,3)= 3*3 =9的卷积核

卷积核将同时卷积所有三个通道(input_image depth=3)。所以一个卷积核的参数是3 * 3 * 3=27
[卷积核大小 * 通道数]

偏置: 1

[每个卷积核加一个偏置]

对于深度3,一个大小为(3,3)的卷积核的总参数=(3 * 3 * 3)+1=28

卷积核的总数= 5。

卷积核的总参数:5个大小为(3,3),input_image depth(3)= 28*5=140

如何计算输出形状?

n = 10, f = 3 s = 1, p = 0

默认Stride =1,没有提到填充(所以,p=0)

输出形状= n-f+1 = 10-3 +1 =8

在使用卷积核对输入图像应用卷积后,输出将是一个特征映射。特征图中的通道数量取决于所使用的过滤器的数量。在这个例子中,使用了5个卷积核。因此,特征图中的通道数为5。

因此,特征图的Output_shape = (8,8,5)

model3=keras.models.Sequential()
model3.add(Conv2D(filters=5,kernel_size=3,input_shape=(10,10,3),activation='relu'))
model3.summary()

作者:Indhumathy Chelliah

Tags:

最近发表
标签列表