Today's topic is a model-training exercise that people implement all the time: training a generative adversarial network (GAN) on the MNIST dataset. MNIST contains a large number of handwritten digit images, which makes it well suited for testing how a generative model performs. Before starting the implementation, let's first get to know generative adversarial networks.
Generative adversarial networks (GANs)
A GAN consists of two neural networks trained against each other: a generator, which maps random noise to fake images, and a discriminator, which tries to tell real images apart from generated ones. As the two are trained alternately, the generator gradually learns to produce images that the discriminator can no longer distinguish from real data.
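For reference, the training objective behind this setup is the standard minimax game from the original GAN paper (Goodfellow et al., 2014); it is not computed explicitly in the code below, but it is what the alternating training procedure approximates:

min_G max_D V(D, G) = E_{x ~ p_data}[ log D(x) ] + E_{z ~ p_z}[ log(1 - D(G(z))) ]

Here D(x) is the discriminator's probability that x is real and G(z) is an image generated from noise z; the discriminator tries to maximise V while the generator tries to minimise it.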
Implementation
First, import the required libraries and modules, then load and preprocess the MNIST data:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.optimizers import Adam
(X_train, _), (_, _) = mnist.load_data()
# Scale pixel values to [-1, 1] to match the generator's tanh output
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)
batch_size = 128
z_dim = 100
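As a quick sanity check before building the models (a minimal sketch; the expected shape assumes the standard MNIST training split of 60,000 images), you can print the shape and value range of the preprocessed data:

print(X_train.shape)                 # expected: (60000, 28, 28, 1)
print(X_train.min(), X_train.max())  # expected: about -1.0 and 1.0 after rescaling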
The generator model:
def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(256, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    # tanh keeps the output pixels in [-1, 1], matching the rescaled training data
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model
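To make sure the generator produces images of the expected size, a minimal smoke test could look like the following (the throwaway names g and sample are just for illustration):

g = build_generator(z_dim)
sample = g.predict(np.random.normal(0, 1, (1, z_dim)), verbose=0)
print(sample.shape)  # expected: (1, 28, 28, 1), with values in (-1, 1) from tanh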
The discriminator model:
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(1, activation='sigmoid'))
    return model
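A similar quick check (again only a sketch with throwaway names d and probs) confirms that the discriminator maps a batch of 28x28x1 images to one probability each:

d = build_discriminator((28, 28, 1))
probs = d.predict(np.random.uniform(-1, 1, (4, 28, 28, 1)), verbose=0)
print(probs.shape)  # expected: (4, 1), each value in (0, 1) from the sigmoid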
Build the generator and discriminator, and compile the discriminator:
generator = build_generator(z_dim)
discriminator = build_discriminator((28, 28, 1))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
Build the combined adversarial model, which stacks the generator and the frozen discriminator:
z = Input(shape=(z_dim,))
img = generator(z)
discriminator.trainable = False  # freeze the discriminator in the combined model; it was compiled above, so direct training still updates it
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
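One thing worth verifying (this check is not part of the original walkthrough, just a sketch): because discriminator.trainable was set to False before gan was built, the combined model should only update the generator's weights, while the separately compiled discriminator still learns when trained directly.

print(len(gan.trainable_weights))        # should equal the number of generator weights
print(len(generator.trainable_weights))  # 8 here: kernel + bias for each of the 4 Dense layers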
The training loop is the most involved part. First, randomly select a batch of real images:
def train(epochs, batch_size=128):
    for epoch in range(epochs):
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]
Generate fake images and create the labels:
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        fake_imgs = generator.predict(noise, verbose=0)
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))
Train the discriminator, then the generator:
        d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(noise, real_labels)
Print the training progress, handling the fact that train_on_batch may return either a scalar or a list for g_loss:
        if epoch % 100 == 0:
            d_loss_value = d_loss[0]
            d_loss_acc = d_loss[1]
            if isinstance(g_loss, list):
                g_loss_value = g_loss[0]
            else:
                g_loss_value = g_loss
            print(f"{epoch} [D loss: {d_loss_value:.4f}, acc.: {100 * d_loss_acc:.2f}%] [G loss: {g_loss_value:.4f}]")
            sample_images(epoch)
Finally, define a helper that saves a grid of generated images, then kick off the training:
def sample_images(epoch):
    noise = np.random.normal(0, 1, (25, z_dim))
    generated_imgs = generator.predict(noise, verbose=0)
    generated_imgs = 0.5 * generated_imgs + 0.5  # rescale from [-1, 1] back to [0, 1] for display
    generated_imgs = generated_imgs.reshape(25, 28, 28)
    plt.figure(figsize=(5, 5))
    for i in range(generated_imgs.shape[0]):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_imgs[i], interpolation='nearest', cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"gan_mnist_epoch_{epoch}.png")
    plt.close()
train(epochs=10, batch_size=batch_size)
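After training finishes, it can be handy to keep the generator so that new digits can be sampled later without retraining; a minimal sketch (the filename gan_mnist_generator.h5 is just an example) could be:

generator.save("gan_mnist_generator.h5")
# later, e.g. in another session:
# from tensorflow.keras.models import load_model
# g = load_model("gan_mnist_generator.h5")
# new_imgs = g.predict(np.random.normal(0, 1, (5, z_dim)), verbose=0)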
A small takeaway: training this model took far longer than anything I had trained before. After looking into it, my conclusion is that a GAN is made up of two neural networks, each usually containing multiple layers of neurons, so updating the weights and computing the losses for both requires a lot of computation.