网站首页 > 技术文章 正文
图像(二维)卷积示意图
图像(二维)卷积说明
source pixel是原始图像,Convolution Kernel是卷积核,New pixel value是卷积运算后的值。卷积核是一个3*3的窗口,对图像的卷积就是使卷积核与原始图像的像素相对应的位置先做乘法然后把相应位置得到的结果进行求和,最终把求和的结果赋给目标像素。如上图中的计算,最终得到的结果为-8。把原始图像上的点全部进行卷积运算后,就生成了一张新的图像。卷积后,得到的数值越大,说明图像与卷积核特征相似度越高。那么,卷积神经网络训练的目的就是求解最优化的能够代表图像性质的卷积核,也可以称之为网络的权重矩阵。
最大池化运算
最大池化运算体现的是图像中最大值相对位置信息。
建立卷积神经网络
下载数据
这里使用一个猫和狗的图片集,下载地址:http://www.kaggle.com/c/dogs-vs-cats/data,这个网站的注册有些麻烦,也可以通过百度网盘下载:https://pan.baidu.com/s/1wozoz8fbxMF5M-flWNBr_Q。下载完成后图片的分类与筛选过程可以在本教程最后下载完整源码查看。这里我们只对核心的网络构建部分进行说明。
构建网络模型
model = models.Sequential()
#添加一层卷积网络层,32表示输出空间的维度,(3,3)表示卷积核大小
model.add(layers.Conv2D(32, (3,3), activation = 'relu', input_shape = (150,150,3)))
#添加一层最大池化网络层
model.add(layers.MaxPooling2D(2,2))
model.add(layers.Conv2D(64, (3,3), activation = 'relu'))
model.add(layers.MaxPool2D(2,2))
model.add(layers.Conv2D(128, (3,3), activation = 'relu'))
model.add(layers.MaxPool2D(2,2))
model.add(layers.Conv2D(128, (3,3), activation = 'relu'))
model.add(layers.MaxPool2D(2,2))model.add(layers.Flatten())
#添加Dropout正则化
model.add(layers.Dropout(0.5))
#添加两层密集层
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
网络概况
图像数据预处理
神经网络的输入必须是[0,1]之间,所以需要把图片做一些预处理,并做一些数据的扩展。
#利用这个函数进行预处理,增加图片数据的数量
train_datagen = ImageDataGenerator( rescale=1./255, #像素点范围缩放到[0,1]
rotation_range=40, #图像随机旋转的度数范围
width_shift_range=0.2, #图像在水平方向移动的范围(宽度的20%)
height_shift_range=0.2, #图像在垂直方向移动的范围(高度的20%)
shear_range=0.2, #剪切的强度范围
zoom_range=0.2, #随机缩放的范围
horizontal_flip=True, #随机水平翻转
fill_mode='nearest' #输入边界以外的点的填充方式
)#将train_dir文件夹中的数据用train_datagen方法处理
train_generator = train_datagen.flow_from_directory( train_dir, target_size=(150,150), batch_size=32, class_mode='binary')
小结
本节重点介绍了卷积的含义,并介绍了卷积神经网络的构建过程。完整的源码下载,可。通过本节的学习,你需要了解:
1. 图像卷积的概念;
1. 最大池化的概念;
1. 构建卷积神经网络的方法;
1. keras进行图像预处理的方法。
预训练网络
什么是预训练网络
预训练网络是一个保存好的网络,该网络已经在大型数据集上训练好,由于训练用数据集较大,可以涵盖一些特殊的应用,具有一般意义的通用性。
Keras常用预训练网络
使用VGG16训练模型
实例化VGG16卷积网络
# 实例化VGG16卷积
# weights: 指定模型初始化的权重检查点。
# include_top: 指定模型最后是否连接密集分类器
# input_shape: 输入网络中图片张量的形状
conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(150,150,3))
使用预训练卷积提取特征
datagen = ImageDataGenerator(rescale=1./255)
batch_size = 20
def extract_features(directory, sample_count):
features = np.zeros(shape=(sample_count,4,4,512))
labels = np.zeros(shape=(sample_count))
generator = datagen.flow_from_directory( directory, target_size=(150,150), batch_size=batch_size, class_mode='binary')
i = 0
for input_batch, labels_batch in generator:
features_batch = conv_base.predict(input_batch) #为输入样本计算输出预测 features[i*batch_size:(i+1)*batch_size] = features_batch labels[i*batch_size: (i+1)*batch_size] = labels_batch i += 1 if i * batch_size >= sample_count: break return features, labels
添加密集层网络
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4*4*512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile( optimizer = optimizers.RMSprop(lr = 2e-5), loss='binary_crossentropy', metrics=['acc'])
history = model.fit(train_features, train_labels, epochs=30, batch_size=20, validation_data=(validation_features, validation_labels))
小结
通过本小结的学习,需要能够了解预训练网络的含义与使用方法。在具体的使用过程中需根据具体需求选择预训练网络。
- 上一篇: 用Keras实现深度学习网络
- 下一篇: AI医生诊断肺炎
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 LeNet-5 一个应用于图像分类问题的卷积神经网络
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 最近发表
- 标签列表
-
- cmd/c (57)
- c++中::是什么意思 (57)
- sqlset (59)
- ps可以打开pdf格式吗 (58)
- phprequire_once (61)
- localstorage.removeitem (74)
- routermode (59)
- vector线程安全吗 (70)
- & (66)
- java (73)
- org.redisson (64)
- log.warn (60)
- cannotinstantiatethetype (62)
- js数组插入 (83)
- resttemplateokhttp (59)
- gormwherein (64)
- linux删除一个文件夹 (65)
- mac安装java (72)
- reader.onload (61)
- outofmemoryerror是什么意思 (64)
- flask文件上传 (63)
- eacces (67)
- 查看mysql是否启动 (70)
- java是值传递还是引用传递 (58)
- 无效的列索引 (74)