GAN（生成對抗網路，Generative Adversarial Network）的魅力在於它能夠生成逼真的圖像。在這篇文章中，我們將探索如何利用 GAN 生成 MNIST 手寫數字。
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
# Load the MNIST training images; labels and the test split are not needed
# for this unconditional GAN.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1]. The generator ends in a tanh layer
# (range [-1, 1]) and the display code maps samples back with 0.5 * x + 0.5,
# so real images must use the same range; with [0, 1] the discriminator could
# tell real from fake simply by the presence of negative pixel values.
x_train = (x_train.astype('float32') - 127.5) / 127.5
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = np.expand_dims(x_train, axis=-1)
def build_generator(latent_dim=100):
    """Build the fully connected generator mapping latent noise to a 28x28x1 image.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the size
            used by the rest of this script.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose output has shape
        ``(batch, 28, 28, 1)`` with values in [-1, 1] (tanh activation).
    """
    model = tf.keras.Sequential()
    # Declare the input with an explicit Input layer instead of passing
    # ``input_dim`` to the first Dense layer (Keras 3 warns about that form).
    model.add(layers.Input(shape=(latent_dim,)))
    # Three widening Dense + BatchNormalization stages: 256 -> 512 -> 1024 units.
    for units in (256, 512, 1024):
        model.add(layers.Dense(units, activation='relu'))
        model.add(layers.BatchNormalization())
    # tanh keeps pixel values in [-1, 1]; the display code maps them back to
    # [0, 1] with 0.5 * x + 0.5.
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model
def build_discriminator(img_shape=(28, 28, 1)):
    """Build the fully connected discriminator that scores images as real or fake.

    Args:
        img_shape: Shape of a single input image. Defaults to ``(28, 28, 1)``,
            matching the generator's output.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model producing one sigmoid
        output per image. Training labels real images 1 and generated images 0,
        so the output is the estimated probability that the image is real.
    """
    model = tf.keras.Sequential()
    # Explicit Input layer instead of the ``input_shape`` keyword on Flatten,
    # which Keras 3 warns about.
    model.add(layers.Input(shape=img_shape))
    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(256, activation='relu'))
    # Binary classification: a single sigmoid unit.
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
# Instantiate the two networks of the GAN.
generator = build_generator()
discriminator = build_discriminator()
# Compile the stand-alone discriminator while it is still trainable; train_gan
# updates it directly through discriminator.train_on_batch(...). Because of
# metrics=['accuracy'], train_on_batch returns [loss, accuracy].
# NOTE(review): this pattern relies on Keras snapshotting `trainable` at
# compile time (tf.keras 2.x behavior) - confirm under the Keras version in use.
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Combined GAN model: noise -> generator -> discriminator.
# Freeze the discriminator before compiling `gan`, so that gan.train_on_batch
# is intended to update only the generator's weights.
discriminator.trainable = False
# 100 = length of the generator's input noise vector (see build_generator).
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
# The generator is trained through this model with binary cross-entropy.
gan.compile(loss='binary_crossentropy', optimizer='adam')
def train_gan(epochs, batch_size, latent_dim=100, log_interval=100):
    """Train the GAN by alternating discriminator and generator updates.

    Despite its name, ``epochs`` counts training *iterations*: each iteration
    draws one random batch of real images, not a full pass over ``x_train``.

    Uses the module-level ``x_train``, ``generator``, ``discriminator`` and
    ``gan`` objects.

    Args:
        epochs: Number of training iterations to run.
        batch_size: Number of real images, and of generated images, per
            discriminator step; also the generator's batch size.
        latent_dim: Length of the noise vectors fed to the generator. Must match
            the generator's input size (100 in the rest of this script).
        log_interval: Print the losses every ``log_interval`` iterations.
            A falsy value (0 or None) disables logging.
    """
    # Labels never change between iterations, so build them once.
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # --- Train the discriminator ---
        # Sample a random batch of real images (indices drawn with replacement).
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        # Generate a batch of fake images. verbose=0 stops predict() from
        # printing a progress bar on every single iteration.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generated_images = generator.predict(noise, verbose=0)
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        # Element-wise mean of the [loss, accuracy] pairs from both halves.
        discriminator_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Train the generator ---
        # Fresh fakes are labelled "real": the generator improves by fooling the
        # discriminator, which was frozen before `gan` was compiled.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generator_loss = gan.train_on_batch(noise, real_labels)

        if log_interval and epoch % log_interval == 0:
            print(f"Epoch: {epoch}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")
# Run 10,000 training iterations, each on a batch of 64 images.
train_gan(batch_size=64, epochs=10000)
def generate_and_display_images(num_images, latent_dim=100):
    """Sample ``num_images`` digits from the module-level generator and plot them in one row.

    Args:
        num_images: Number of images to generate and display; must be >= 1.
        latent_dim: Length of the noise vectors. Must match the generator's
            input size (100 in the rest of this script).

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        # Fail early with a clear message instead of deep inside numpy/Keras.
        raise ValueError(f"num_images must be at least 1, got {num_images}")
    noise = np.random.normal(0, 1, (num_images, latent_dim))
    # verbose=0 suppresses predict()'s progress bar.
    generated_images = generator.predict(noise, verbose=0)
    # Map the generator's tanh output from [-1, 1] back to [0, 1].
    generated_images = 0.5 * generated_images + 0.5
    plt.figure(figsize=(10, 10))
    for i, image in enumerate(generated_images, start=1):
        plt.subplot(1, num_images, i)
        # Pin the colour range to [0, 1]. Without vmin/vmax, imshow stretches
        # each image to its own min/max, which hides brightness differences and
        # makes the de-normalization above irrelevant.
        plt.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
    plt.show()
# Show ten freshly generated digits side by side.
generate_and_display_images(num_images=10)
這項技術不僅展示了人工智慧的強大能力，也讓我們能夠創造出全新的數位藝術作品。GAN 的潛力無窮，在影像生成、風格轉換等多個領域都有廣泛的應用。