优秀的编程知识分享平台

网站首页 > 技术文章 正文

如何使用Keras构建残差神经网络?

nanyue 2024-11-23 20:16:57 技术文章 1 ℃


什么是残差神经网络?

原则上,神经网络的层数越多,应获得越好的结果。一个更深层的网络可以学到任何浅层的东西,甚至可能更多。如果对于给定的数据集,网络无法通过添加更多的层来学习更多东西,那么它就可以学习这些其他层的恒等映射(identity mappings)。这样,它可以保留先前层中的信息,并且不会比较浅的层更糟糕。

但是,实际上情况并非如此。越深的网络越难优化。随着我们向网络中添加层,我们在训练过程中的难度也会增加;用于查找正确参数的优化算法也会变得越来越困难。随着我们添加更多层,网络将获得更好的结果(直到某个时候为止)。然后,随着我们继续添加额外的层,准确性开始下降。

残差网络试图通过添加所谓的skip connections来解决此问题。如前所述,更深层的网络至少应该能够学习恒等映射(identity mappings)。skip connections是这样做的:它们从网络中的一个点到另一点添加恒等映射,然后让网络仅学习额外的()。如果网络没有其他可以学习的东西,那么它仅将()设为0。事实证明,对于网络来说,学习一个更接近于0的映射比学习恒等映射更容易。

具有skip connection的块称为残差块,而残差神经网络(ResNet)只是这些块的连接。

Keras Functional API简介

可能您已经熟悉了Sequential类,它可以让一个人很容易地构建一个神经网络,只要把层一个接一个地堆叠起来,就像这样:

但是,这种构建神经网络的方式不足以满足我们的需求。使用Sequential类,我们无法添加skip connections。Keras的Model类可与Functional API一起使用,以创建用于构建更复杂的网络体系结构的层。构造后,keras.layers.Input返回张量对象。Keras中的层对象也可以像函数一样使用,以张量对象作为参数来调用它。返回的对象是张量,然后可以将其作为输入传递到另一层,依此类推。

举个例子:

这种语法的真正用途是在使用所谓的“ Merge”层时,通过该层可以合并更多输入张量。这些层中的一些例子是:Add,Subtract,Multiply,Average。我们在构建剩余块时需要的是Add。

使用的Add示例:

ResNet的Python实现

接下来,我们将实现一个ResNet和其普通(无skip connections)副本,以进行比较。

我们将在此处构建的ResNet具有以下结构:

  • 形状为(32,32,3)的输入
  • 1个Conv2D层,64个filters
  • 2、5、5、2残差块的filters分别为64、128、256和512
  • 池大小= 4的AveragePooling2D层
  • Flatten层
  • 10个输出节点的Dense层

它共有30个conv+dense层。所有的核大小都是3x3。我们在conv层之后使用ReLU激活和BatchNormalization。我们首先创建一个辅助函数,将张量作为输入并为其添加relu和批归一化:

然后,我们创建一个用于构造残差块的函数。

create_res_net()函数将所有内容组合在一起。这是完整的代码:

from tensorflow import Tensor
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\
                                    Add, AveragePooling2D, Flatten, Dense
from tensorflow.keras.models import Model

def relu_bn(inputs: Tensor) -> Tensor:
    relu = ReLU()(inputs)
    bn = BatchNormalization()(relu)
    return bn

def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    y = Conv2D(kernel_size=kernel_size,
               strides= (1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    if downsample:
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    out = Add()([x, y])
    out = relu_bn(out)
    return out

def create_res_net():
    
    inputs = Input(shape=(32, 32, 3))
    num_filters = 64
    
    t = BatchNormalization()(inputs)
    t = Conv2D(kernel_size=3,
               strides=1,
               filters=num_filters,
               padding="same")(t)
    t = relu_bn(t)
    
    num_blocks_list = [2, 5, 5, 2]
    for i in range(len(num_blocks_list)):
        num_blocks = num_blocks_list[i]
        for j in range(num_blocks):
            t = residual_block(t, downsample=(j==0 and i!=0), filters=num_filters)
        num_filters *= 2
    
    t = AveragePooling2D(4)(t)
    t = Flatten()(t)
    outputs = Dense(10, activation='softmax')(t)
    
    model = Model(inputs, outputs)

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

普通网络以类似的方式构建,但它没有skip connections,我们也不使用residual_block()帮助函数;一切都在create_plain_net()中完成。plain network的Python代码如下:

from tensorflow import Tensor
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\
                                    AveragePooling2D, Flatten, Dense
from tensorflow.keras.models import Model

def relu_bn(inputs: Tensor) -> Tensor:
    relu = ReLU()(inputs)
    bn = BatchNormalization()(relu)
    return bn

def create_plain_net():
    
    inputs = Input(shape=(32, 32, 3))
    num_filters = 64
    
    t = BatchNormalization()(inputs)
    t = Conv2D(kernel_size=3,
               strides=1,
               filters=num_filters,
               padding="same")(t)
    t = relu_bn(t)
    
    num_blocks_list = [4, 10, 10, 4]
    for i in range(len(num_blocks_list)):
        num_blocks = num_blocks_list[i]
        for j in range(num_blocks):
            downsample = (j==0 and i!=0)
            t = Conv2D(kernel_size=3,
                       strides= (1 if not downsample else 2),
                       filters=num_filters,
                       padding="same")(t)
            t = relu_bn(t)
        num_filters *= 2
    
    t = AveragePooling2D(4)(t)
    t = Flatten()(t)
    outputs = Dense(10, activation='softmax')(t)
    
    model = Model(inputs, outputs)

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

训练CIFAR-10并查看结果

CIFAR-10是一个包含10个类别的32x32 rgb图像的机器学习数据集。它包含了50k的训练图像和10k的测试图像。

以下是来自每个类别的10张随机图片样本:

我们将在这个机器学习数据集上对ResNet和PlainNet进行20个epoch的训练,然后比较结果。

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import datetime
import os

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

model = create_res_net() # or create_plain_net()
model.summary()

timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
name = 'cifar-10_res_net_30-'+timestr # or 'cifar-10_plain_net_30-'+timestr

checkpoint_path = "checkpoints/"+name+"/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.system('mkdir {}'.format(checkpoint_dir))

# save model after each epoch
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1
)
tensorboard_callback = TensorBoard(
    log_dir='tensorboard_logs/'+name,
    histogram_freq=1
)

model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    verbose=1,
    validation_data=(x_test, y_test),
    batch_size=128
)

ResNet和PlainNet在训练时间上没有显著差异。我们得到的结果如下所示。


因此,通过在该机器学习数据集上使用ResNet ,我们将验证准确性提高了1.59%。在更深层的网络上,差异应该更大。

Tags:

最近发表
标签列表