优秀的编程知识分享平台

网站首页 > 技术文章 正文

如何使用Keras实现GAC模型生成MNIST数字图像

nanyue 2024-11-23 20:17:22 技术文章 3 ℃

#头条创作挑战赛#

使用Keras实现Generative Adversarial Network(GAN)模型来生成MNIST数字图像的步骤如下:

1)导入所需的库:

import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout
from keras.optimizers import Adam

2)加载MNIST数据集:

(X_train, _), (_, _) = mnist.load_data()

# 对数据做归一化和重新调整形状
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)

3)定义生成器(Generator)和判别器(Discriminator)模型:

def build_generator():
    model = Sequential()
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=100))
    model.add(Reshape((7, 7, 128)))
    model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", activation="relu"))
    model.add(Conv2DTranspose(1, kernel_size=4, strides=2, padding="same", activation="tanh"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
    return model

def build_discriminator():
    model = Sequential()
    model.add(Conv2D(64, kernel_size=4, strides=2, padding="same", input_shape=(28, 28, 1)))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(1, activation="sigmoid"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
    return model

# 实例化生成器和判别器模型
generator = build_generator()
discriminator = build_discriminator()

4)定义GAN模型:

def build_gan(generator, discriminator):
    discriminator.trainable = False
    gan_input = Input(shape=(100,))
    gan_output = discriminator(generator(gan_input))
    gan = Model(gan_input, gan_output)
    gan.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
    return gan

# 实例化GAN模型
gan = build_gan(generator, discriminator)

5)训练GAN模型:

def train_gan(gan, generator, discriminator, X_train, epochs=50, batch_size=128, sample_interval=200):
    for epoch in range(epochs):
        # 随机选择一批真实图像
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]

        # 生成一批噪声作为输入
        noise = np.random.normal(0, 1, (batch_size, 100))

        # 生成假图像
        generated_images = generator.predict(noise)

        # 训练判别器
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

        # 输出训练过程中的损失
        if epoch % sample_interval == 0:
            print(f"Epoch {epoch}: discriminator loss = {discriminator_loss}, generator loss = {generator_loss}")
            # 保存生成的图像
            sample_images(generator, epoch)

# 定义保存生成图像的函数
def sample_images(generator, epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig(f"images/mnist_{epoch}.png")
    plt.close()

# 开始训练GAN模型
train_gan(gan, generator, discriminator, X_train)

通过以上步骤,你可以使用Keras实现一个简单的GAN模型来生成MNIST数字图像。训练过程中,生成器和判别器模型会相互竞争,生成器尝试生成接近真实图像的假图像,而判别器则尝试区分真实图像和假图像。随着训练的进行,生成器会逐渐学习生成逼真的图像。

Tags:

最近发表
标签列表