网站首页 > 技术文章 正文
什么是残差神经网络?
原则上,神经网络的层数越多,应获得越好的结果。一个更深层的网络可以学到任何浅层的东西,甚至可能更多。如果对于给定的数据集,网络无法通过添加更多的层来学习更多东西,那么它就可以学习这些其他层的恒等映射(identity mappings)。这样,它可以保留先前层中的信息,并且不会比较浅的层更糟糕。
但是,实际上情况并非如此。越深的网络越难优化。随着我们向网络中添加层,我们在训练过程中的难度也会增加;用于查找正确参数的优化算法也会变得越来越困难。随着我们添加更多层,网络将获得更好的结果(直到某个时候为止)。然后,随着我们继续添加额外的层,准确性开始下降。
残差网络试图通过添加所谓的skip connections来解决此问题。如前所述,更深层的网络至少应该能够学习恒等映射(identity mappings)。skip connections是这样做的:它们从网络中的一个点到另一点添加恒等映射,然后让网络仅学习额外的()。如果网络没有其他可以学习的东西,那么它仅将()设为0。事实证明,对于网络来说,学习一个更接近于0的映射比学习恒等映射更容易。
具有skip connection的块称为残差块,而残差神经网络(ResNet)只是这些块的连接。
Keras Functional API简介
可能您已经熟悉了Sequential类,它可以让一个人很容易地构建一个神经网络,只要把层一个接一个地堆叠起来,就像这样:
但是,这种构建神经网络的方式不足以满足我们的需求。使用Sequential类,我们无法添加skip connections。Keras的Model类可与Functional API一起使用,以创建用于构建更复杂的网络体系结构的层。构造后,keras.layers.Input返回张量对象。Keras中的层对象也可以像函数一样使用,以张量对象作为参数来调用它。返回的对象是张量,然后可以将其作为输入传递到另一层,依此类推。
举个例子:
这种语法的真正用途是在使用所谓的“ Merge”层时,通过该层可以合并更多输入张量。这些层中的一些例子是:Add,Subtract,Multiply,Average。我们在构建剩余块时需要的是Add。
使用的Add示例:
ResNet的Python实现
接下来,我们将实现一个ResNet和其普通(无skip connections)副本,以进行比较。
我们将在此处构建的ResNet具有以下结构:
- 形状为(32,32,3)的输入
- 1个Conv2D层,64个filters
- 2、5、5、2残差块的filters分别为64、128、256和512
- 池大小= 4的AveragePooling2D层
- Flatten层
- 10个输出节点的Dense层
它共有30个conv+dense层。所有的核大小都是3x3。我们在conv层之后使用ReLU激活和BatchNormalization。我们首先创建一个辅助函数,将张量作为输入并为其添加relu和批归一化:
然后,我们创建一个用于构造残差块的函数。
create_res_net()函数将所有内容组合在一起。这是完整的代码:
from tensorflow import Tensor
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\
Add, AveragePooling2D, Flatten, Dense
from tensorflow.keras.models import Model
def relu_bn(inputs: Tensor) -> Tensor:
relu = ReLU()(inputs)
bn = BatchNormalization()(relu)
return bn
def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
y = Conv2D(kernel_size=kernel_size,
strides= (1 if not downsample else 2),
filters=filters,
padding="same")(x)
y = relu_bn(y)
y = Conv2D(kernel_size=kernel_size,
strides=1,
filters=filters,
padding="same")(y)
if downsample:
x = Conv2D(kernel_size=1,
strides=2,
filters=filters,
padding="same")(x)
out = Add()([x, y])
out = relu_bn(out)
return out
def create_res_net():
inputs = Input(shape=(32, 32, 3))
num_filters = 64
t = BatchNormalization()(inputs)
t = Conv2D(kernel_size=3,
strides=1,
filters=num_filters,
padding="same")(t)
t = relu_bn(t)
num_blocks_list = [2, 5, 5, 2]
for i in range(len(num_blocks_list)):
num_blocks = num_blocks_list[i]
for j in range(num_blocks):
t = residual_block(t, downsample=(j==0 and i!=0), filters=num_filters)
num_filters *= 2
t = AveragePooling2D(4)(t)
t = Flatten()(t)
outputs = Dense(10, activation='softmax')(t)
model = Model(inputs, outputs)
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model
普通网络以类似的方式构建,但它没有skip connections,我们也不使用residual_block()帮助函数;一切都在create_plain_net()中完成。plain network的Python代码如下:
from tensorflow import Tensor
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\
AveragePooling2D, Flatten, Dense
from tensorflow.keras.models import Model
def relu_bn(inputs: Tensor) -> Tensor:
relu = ReLU()(inputs)
bn = BatchNormalization()(relu)
return bn
def create_plain_net():
inputs = Input(shape=(32, 32, 3))
num_filters = 64
t = BatchNormalization()(inputs)
t = Conv2D(kernel_size=3,
strides=1,
filters=num_filters,
padding="same")(t)
t = relu_bn(t)
num_blocks_list = [4, 10, 10, 4]
for i in range(len(num_blocks_list)):
num_blocks = num_blocks_list[i]
for j in range(num_blocks):
downsample = (j==0 and i!=0)
t = Conv2D(kernel_size=3,
strides= (1 if not downsample else 2),
filters=num_filters,
padding="same")(t)
t = relu_bn(t)
num_filters *= 2
t = AveragePooling2D(4)(t)
t = Flatten()(t)
outputs = Dense(10, activation='softmax')(t)
model = Model(inputs, outputs)
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model
训练CIFAR-10并查看结果
CIFAR-10是一个包含10个类别的32x32 rgb图像的机器学习数据集。它包含了50k的训练图像和10k的测试图像。
以下是来自每个类别的10张随机图片样本:
我们将在这个机器学习数据集上对ResNet和PlainNet进行20个epoch的训练,然后比较结果。
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import datetime
import os
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
model = create_res_net() # or create_plain_net()
model.summary()
timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
name = 'cifar-10_res_net_30-'+timestr # or 'cifar-10_plain_net_30-'+timestr
checkpoint_path = "checkpoints/"+name+"/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.system('mkdir {}'.format(checkpoint_dir))
# save model after each epoch
cp_callback = ModelCheckpoint(
filepath=checkpoint_path,
verbose=1
)
tensorboard_callback = TensorBoard(
log_dir='tensorboard_logs/'+name,
histogram_freq=1
)
model.fit(
x=x_train,
y=y_train,
epochs=10,
verbose=1,
validation_data=(x_test, y_test),
batch_size=128
)
ResNet和PlainNet在训练时间上没有显著差异。我们得到的结果如下所示。
因此,通过在该机器学习数据集上使用ResNet ,我们将验证准确性提高了1.59%。在更深层的网络上,差异应该更大。
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 LeNet-5 一个应用于图像分类问题的卷积神经网络
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 1507℃桌面软件开发新体验!用 Blazor Hybrid 打造简洁高效的视频处理工具
- 505℃Dify工具使用全场景:dify-sandbox沙盒的原理(源码篇·第2期)
- 485℃MySQL service启动脚本浅析(r12笔记第59天)
- 465℃服务器异常重启,导致mysql启动失败,问题解决过程记录
- 462℃启用MySQL查询缓存(mysql8.0查询缓存)
- 443℃「赵强老师」MySQL的闪回(赵强iso是哪个大学毕业的)
- 422℃mysql服务怎么启动和关闭?(mysql服务怎么启动和关闭)
- 418℃MySQL server PID file could not be found!失败
- 最近发表
-
- netty系列之:搭建HTTP上传文件服务器
- 让deepseek教我将deepseek接入word
- 前端大文件分片上传断点续传(前端大文件分片上传断点续传怎么操作)
- POST 为什么会发送两次请求?(post+为什么会发送两次请求?怎么回答)
- Jmeter之HTTP请求与响应(jmeter运行http请求没反应)
- WAF-Bypass之SQL注入绕过思路总结
- 用户疯狂点击上传按钮,如何确保只有一个上传任务在执行?
- 二 计算机网络 前端学习 物理层 链路层 网络层 传输层 应用层 HTTP
- HTTP请求的完全过程(http请求的基本过程)
- dart系列之:浏览器中的舞者,用dart发送HTTP请求
- 标签列表
-
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- js判断是否是json字符串 (67)
- checkout-b (67)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)