网站首页 > 技术文章 正文
使用Keras实现Generative Adversarial Network(GAN)模型来生成MNIST数字图像的步骤如下:
1)导入所需的库:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout
from keras.optimizers import Adam
2)加载MNIST数据集:
(X_train, _), (_, _) = mnist.load_data()
# 对数据做归一化和重新调整形状
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
3)定义生成器(Generator)和判别器(Discriminator)模型:
def build_generator():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=100))
model.add(Reshape((7, 7, 128)))
model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(Conv2DTranspose(1, kernel_size=4, strides=2, padding="same", activation="tanh"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
def build_discriminator():
model = Sequential()
model.add(Conv2D(64, kernel_size=4, strides=2, padding="same", input_shape=(28, 28, 1)))
model.add(LeakyReLU(0.2))
model.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
# 实例化生成器和判别器模型
generator = build_generator()
discriminator = build_discriminator()
4)定义GAN模型:
def build_gan(generator, discriminator):
discriminator.trainable = False
gan_input = Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return gan
# 实例化GAN模型
gan = build_gan(generator, discriminator)
5)训练GAN模型:
def train_gan(gan, generator, discriminator, X_train, epochs=50, batch_size=128, sample_interval=200):
for epoch in range(epochs):
# 随机选择一批真实图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# 生成一批噪声作为输入
noise = np.random.normal(0, 1, (batch_size, 100))
# 生成假图像
generated_images = generator.predict(noise)
# 训练判别器
discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
# 输出训练过程中的损失
if epoch % sample_interval == 0:
print(f"Epoch {epoch}: discriminator loss = {discriminator_loss}, generator loss = {generator_loss}")
# 保存生成的图像
sample_images(generator, epoch)
# 定义保存生成图像的函数
def sample_images(generator, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
generated_images = generator.predict(noise)
generated_images = 0.5 * generated_images + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig(f"images/mnist_{epoch}.png")
plt.close()
# 开始训练GAN模型
train_gan(gan, generator, discriminator, X_train)
通过以上步骤,你可以使用Keras实现一个简单的GAN模型来生成MNIST数字图像。训练过程中,生成器和判别器模型会相互竞争,生成器尝试生成接近真实图像的假图像,而判别器则尝试区分真实图像和假图像。随着训练的进行,生成器会逐渐学习生成逼真的图像。
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 LeNet-5 一个应用于图像分类问题的卷积神经网络
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 02-21走进git时代, 你该怎么玩?_gits
- 02-21GitHub是什么?它可不仅仅是云中的Git版本控制器
- 02-21Git常用操作总结_git基本用法
- 02-21为什么互联网巨头使用Git而放弃SVN?(含核心命令与原理)
- 02-21Git 高级用法,喜欢就拿去用_git基本用法
- 02-21Git常用命令和Git团队使用规范指南
- 02-21总结几个常用的Git命令的使用方法
- 02-21Git工作原理和常用指令_git原理详解
- 最近发表
- 标签列表
-
- cmd/c (57)
- c++中::是什么意思 (57)
- sqlset (59)
- ps可以打开pdf格式吗 (58)
- phprequire_once (61)
- localstorage.removeitem (74)
- routermode (59)
- vector线程安全吗 (70)
- & (66)
- java (73)
- org.redisson (64)
- log.warn (60)
- cannotinstantiatethetype (62)
- js数组插入 (83)
- resttemplateokhttp (59)
- gormwherein (64)
- linux删除一个文件夹 (65)
- mac安装java (72)
- reader.onload (61)
- outofmemoryerror是什么意思 (64)
- flask文件上传 (63)
- eacces (67)
- 查看mysql是否启动 (70)
- java是值传递还是引用传递 (58)
- 无效的列索引 (74)