网站首页 > 技术文章 正文
LeNet-5是一个应用于图像分类问题的卷积神经网络,其学习目标是从一系列由32×32×1灰度图像表示的手写数字中识别和区分0-9。LeNet-5的隐含层由2个卷积层、2个池化层构筑和2个全连接层组成:
构建方式:
- (3×3)×1×6的卷积层(步长为1,无填充),2×2均值池化(步长为2,无填充),tanh激励函数
- (5×5)×6×16的卷积层(步长为1,无填充),2×2均值池化(步长为2,无填充),tanh激励函数
- 2个全连接层,神经元数量为120和84
LeNet-5特点:
1、LeNet-5规模很小:从现代深度学习的观点来看,LeNet-5规模很小,但考虑LeCun et al. (1998)的数值计算条件,LeNet-5在该时期仍具有相当的复杂度 。
2、LeNet-5使用双曲正切函数作为激励函数。
3、使用均方差(Mean Squared Error, MSE)作为误差函数并对卷积操作进行了修改以减少计算开销,这些设置在随后的卷积神经网络算法中已被更优化的方法取代。
TensorFlow和Keras的代码范例
在现代机器学习库的范式下,LeNet-5是一个易于实现的算法,这里提供一个使用TensorFlow和Keras的计算例子:
# 导入模块
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 读取MNIST数据
mnist = keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
# 重构数据至4维(样本,像素X,像素Y,通道)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.reshape(x_test.shape+(1,))
x_train, x_test = x_train/255.0, x_test/255.0
# 数据标签
label_train = keras.utils.to_categorical(y_train, 10)
label_test = keras.utils.to_categorical(y_test, 10)
# LeNet-5构筑
model = keras.Sequential([
keras.layers.Conv2D(6, kernel_size=(3, 3), strides=(1, 1), activation='tanh', padding='valid'),
keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
keras.layers.Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'),
keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
keras.layers.Flatten(),
keras.layers.Dense(120, activation='tanh'),
keras.layers.Dense(84, activation='tanh'),
keras.layers.Dense(10, activation='softmax'),
])
# 使用SGD编译模型
model.compile(loss=keras.losses.categorical_crossentropy, optimizer='SGD')
# 学习30个纪元(可依据CPU计算力调整),使用20%数据交叉验证
records = model.fit(x_train, label_train, epochs=20, validation_split=0.2)
# 预测
y_pred = np.argmax(model.predict(x_test), axis=1)
print("prediction accuracy: {}".format(sum(y_pred==y_test)/len(y_test)))
# 绘制结果
plt.plot(records.history['loss'],label='training set loss')
plt.plot(records.history['val_loss'],label='validation set loss')
plt.ylabel('categorical cross-entropy'); plt.xlabel('epoch')
plt.legend()
该例子使用MNIST数据代替LeCun et al. (1998)的原始数据,使用交叉熵(categorical cross-entropy)作为损失函数。
- 上一篇: 如何实现CNN特征层可视化?终于懂了....
- 下一篇: 基于OpencvCV的情绪检测
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 2024-11-23 用于图像降噪的卷积自编码器
- 1509℃桌面软件开发新体验!用 Blazor Hybrid 打造简洁高效的视频处理工具
- 527℃Dify工具使用全场景:dify-sandbox沙盒的原理(源码篇·第2期)
- 492℃MySQL service启动脚本浅析(r12笔记第59天)
- 472℃服务器异常重启,导致mysql启动失败,问题解决过程记录
- 469℃启用MySQL查询缓存(mysql8.0查询缓存)
- 450℃「赵强老师」MySQL的闪回(赵强iso是哪个大学毕业的)
- 429℃mysql服务怎么启动和关闭?(mysql服务怎么启动和关闭)
- 426℃MySQL server PID file could not be found!失败
- 最近发表
- 标签列表
-
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- js判断是否是json字符串 (67)
- checkout-b (67)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)