优秀的编程知识分享平台

网站首页 > 技术文章 正文

智造讲堂:深度学习框架之Keras

nanyue 2024-11-23 20:17:10 技术文章 1 ℃

引自:《深度学习》(作者:文龙, 李新宇)


Keras是深度学习领域另一个常用框架。它以TensorFlow、Theano、CNTK作为后端引擎运行,提供直观而简洁的API,即使是非专业人员,也可以在各自领域轻松使用和开发深度学习模型,如多重感知机、卷积神经网络、循环神经网络以及各种复杂的网络模型(Keras官网为https://keras.io/)。


目前,Keras具有广阔的应用场景,具有良好的模块化设计、用户友好的接口规范等特点。它将大量重复的工作进行抽象并形成接口,使得用户只需采用少量代码完成接口部分即可实现深度学习模型的快速搭建,以节约模型构建时间。同时,Keras支持CPU和GPU的无缝运行,也支持多GPU并行计算。目前,Keras的API已经被TensorFlow借鉴,形成了TensorFlow下的Keras模块。


本案例采用TensorFlow下Keras模块搭建一个CNN模型,并将其应用在Cifar10数据集上。Cifar10数据集是著名的图像分类数据集,包含10个类别,每个类别6000张32x32的彩色图像,总数据量为60000张图像。其中50000张为训练集图像,另外10000张为测试图像。Cifar10数据集下载地址为http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz。但是,Keras已经将Cifar10的下载和使用进行了封装,只需要调用相关函数即可。本案例的操作如下:


(1)导入相关库。如图1所示,本实例代码共导入了TensorFlow库和tensorflow.keras库下的datasets(数据集)、layers(网络层)和models(模型)的相关库。最后导入了绘图软件matplotlib中的绘图模块pyplot,并将其命名为plt。


图1 导入Keras相关库


(2)准备数据集。下载并准备Cifar10数据集。tensorflow.keras中Cifar10的相关接口在datasets类下,因此直接导入即可,如图2中的代码第8行所示。第11行代码为对其进行预处理。第14~26行为展示其中前25个样本,图像输出结果如图3所示。


图2 导入Cifar10数据集


图3 Cifar10数据集中前25幅图像


(3)构造CNN模型。在tensorflow.keras中,卷积层的构造函数为Conv2D,池化层为MaxPooling2D,全连接层为Dense函数。如图4所示,第30行定义了一个模型(model),第31行构建了第一个卷积层,其参数分别表示卷积层的深度32、卷积核的大小3*3、激活函数Relu。其中input_shape表明其输入的tensor维度,为(32, 32, 3)。CNN模型的输入Tensor的形式为(image_height, image_width, color_channels),分别包含了图像高度、宽度及颜色信息。如果不熟悉图像处理,颜色信息建议使用RGB色彩模式,此模式下,color_channels为(R,G,B)分别对应RGB的三个颜色通道(color_channels)。由于Cifar10数据集是32x32的彩色图像,所以其数据形式正好为(32, 32, 3)。


图4 构造CNN模型


第32行构建了一个最大池化层,池化层的参数为2*2。第33~35行定义了交替的卷积层和池化层,其意义不在重复。


第38行调用Flatten()函数将3维张量转换为1维,之后传入一个或者多个Dense层中。本例中第39行和第40行分别定义了两个Dense层。在最后的Dense层中,添加了softmax激活函数,实现对Cifar10数据集中10个类别的预测输出。


第43层展示了该CNN模型各层的参数状况,其结果如图5所示。图中Layer(type)分别表示层名称、层类型,OutputShape表示数据张量在图中的维度变化,Param表示参数量。由此可见,MaxPooling2D会将数据张量的宽度和高度信息降低一半。而Flatten层将4*4*64的三维张量转化为1024的一维张量。


图5 CNN模型各层的参数状况


(4)编译并训练模型。在本例中,由于Cifar10数据集的label是数字编码,故采用SparseCategoricalCrossentropy函数。对于独热编码的情况,可以直接采用CategoricalCrossentropy函数。在本例中,训练的优化器选择Adam优化器,训练中采用的度量指标为准确率(accuracy)。


模型的编译如图6的第46行所示,分别指定模型的优化器、损失函数和度量指标。第50行表示模型的训练,其中train_images和train_labels分别为训练集样本和训练集的label。参数epoches指定了训练次数。此处并未指定批量,故默认为每批使用全部样本。参数validation_data表示验证集。该模型的训练结果如图7所示。从结果上看,经过10步的训练,该CNN模型的预测精度已经达到74.50%。


图6 编译并训练模型


图7 CNN模型在Cifar10数据集上的训练结果

Tags:

最近发表
标签列表