优秀的编程知识分享平台

网站首页 > 技术文章 正文

LeNet-5 一个应用于图像分类问题的卷积神经网络

nanyue 2024-11-23 20:18:46 技术文章 2 ℃

LeNet-5是一个应用于图像分类问题的卷积神经网络,其学习目标是从一系列由32×32×1灰度图像表示的手写数字中识别和区分0-9。LeNet-5的隐含层由2个卷积层、2个池化层构筑和2个全连接层组成:

构建方式:

  1. (3×3)×1×6的卷积层(步长为1,无填充),2×2均值池化(步长为2,无填充),tanh激励函数
  2. (5×5)×6×16的卷积层(步长为1,无填充),2×2均值池化(步长为2,无填充),tanh激励函数
  3. 2个全连接层,神经元数量为120和84

LeNet-5特点:

1、LeNet-5规模很小:从现代深度学习的观点来看,LeNet-5规模很小,但考虑LeCun et al. (1998)的数值计算条件,LeNet-5在该时期仍具有相当的复杂度 。

2、LeNet-5使用双曲正切函数作为激励函数。

3、使用均方差(Mean Squared Error, MSE)作为误差函数并对卷积操作进行了修改以减少计算开销,这些设置在随后的卷积神经网络算法中已被更优化的方法取代。

TensorFlow和Keras的代码范例

在现代机器学习库的范式下,LeNet-5是一个易于实现的算法,这里提供一个使用TensorFlow和Keras的计算例子:

# 导入模块
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 读取MNIST数据
mnist = keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
# 重构数据至4维(样本,像素X,像素Y,通道)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.reshape(x_test.shape+(1,))
x_train, x_test = x_train/255.0, x_test/255.0
# 数据标签
label_train = keras.utils.to_categorical(y_train, 10)
label_test = keras.utils.to_categorical(y_test, 10)
# LeNet-5构筑
model = keras.Sequential([
keras.layers.Conv2D(6, kernel_size=(3, 3), strides=(1, 1), activation='tanh', padding='valid'),
keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
keras.layers.Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'),
keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
keras.layers.Flatten(),
keras.layers.Dense(120, activation='tanh'),
keras.layers.Dense(84, activation='tanh'),
keras.layers.Dense(10, activation='softmax'),
 ])
# 使用SGD编译模型
model.compile(loss=keras.losses.categorical_crossentropy, optimizer='SGD')
# 学习30个纪元(可依据CPU计算力调整),使用20%数据交叉验证
records = model.fit(x_train, label_train, epochs=20, validation_split=0.2)
# 预测
y_pred = np.argmax(model.predict(x_test), axis=1)
print("prediction accuracy: {}".format(sum(y_pred==y_test)/len(y_test)))
# 绘制结果
plt.plot(records.history['loss'],label='training set loss')
plt.plot(records.history['val_loss'],label='validation set loss')
plt.ylabel('categorical cross-entropy'); plt.xlabel('epoch')
plt.legend()

该例子使用MNIST数据代替LeCun et al. (1998)的原始数据,使用交叉熵(categorical cross-entropy)作为损失函数。

Tags:

最近发表
标签列表