iT邦幫忙

2024 iThome 鐵人賽

DAY 27
0
自我挑戰組

30天初探tensorflow之旅系列 第 27

Day 27 用MNIST訓練生成對抗網路(GAN)

  • 分享至 

  • xImage
  •  

今天發現大家常用來實作的一個模型訓練,就是用 MNIST 數據集訓練生成對抗網路(GAN),因為 MNIST 有大量手寫數字圖像,很適合測試生成模型的效果。再實作開始前,一樣先來認識生成對抗網路。

生成對抗網路(generative adversarial networks)

  • 簡介:
    它是非監督式學習的一種方法,通過兩個神經網路相互博弈的方式進行學習。生成網路從潛在空間中隨機取樣作為輸入,輸出結果需要盡量模仿訓練集中的真實樣本。判別網路的輸入則為真實樣本或生成網路的輸出,目的是將生成網路的輸出從真實樣本中盡可能分辨出來,而生成網路則要盡可能地欺騙判別網路。兩個網路相互對抗、不斷調整參數,最終目的是使判別網路無法判斷生成網路的輸出結果是否真實。
  • 基本原理:
    1.生成器(Generator):
    生成器的目標是創造出看起來真實的數據,它接收隨機噪聲作為輸入,並通過多層神經網絡生成樣本(像是圖片)。
    2.判別器(Discriminator):
    判別器的任務是區分真實數據和生成數據,它接收一組數據並輸出一個概率,表示輸入數據是真實的還是生成的。

實作
先導入庫和模組:

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.optimizers import Adam

(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0
X_train = np.expand_dims(X_train, axis=-1)
batch_size = 128
z_dim = 100

生成器模型:

def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(256, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

判別器模型:

def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(1, activation='sigmoid'))
    return model

建立GAN模型:

generator = build_generator(z_dim)
discriminator = build_discriminator((28, 28, 1))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

建立對抗模型:

z = Input(shape=(z_dim,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

訓練模型的部分比較複雜,先是隨機選擇真實圖像:

def train(epochs, batch_size=128):
    for epoch in range(epochs):
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

生成假圖像和給標籤:

noise = np.random.normal(0, 1, (batch_size, z_dim))
        fake_imgs = generator.predict(noise)
real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))

訓練判別器和生成器:

d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(noise, real_labels)

輸出訓練進度並檢查 g_loss 的返回值:

if epoch % 100 == 0:
            d_loss_value = d_loss[0]
            d_loss_acc = d_loss[1]
            if isinstance(g_loss, list):
                g_loss_value = g_loss[0]
            else:
                g_loss_value = g_loss
            print(f"{epoch} [D loss: {d_loss_value:.4f}, acc.: {100 * d_loss_acc:.2f}%] [G loss: {g_loss_value:.4f}]")
            sample_images(epoch)

最後開始訓練:

def sample_images(epoch):
    noise = np.random.normal(0, 1, (25, z_dim))
    generated_imgs = generator.predict(noise)
    generated_imgs = generated_imgs.reshape(25, 28, 28)

    plt.figure(figsize=(5, 5))
    for i in range(generated_imgs.shape[0]):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_imgs[i], interpolation='nearest', cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"gan_mnist_epoch_{epoch}.png")
    plt.close()
train(epochs=10, batch_size=batch_size)

https://ithelp.ithome.com.tw/upload/images/20241011/20169330LKDTHQRJGM.png

小心得:這次訓練模型花的時間比我以前都還要久很多,去查詢後得到的結論是因為 GAN 由兩個神經網路組成,它們通常包含多層神經元,就會需要大量的計算資源來更新權重和計算損失。


上一篇
Day 26 關於長短期記憶網路(LSTM)
下一篇
Day 28 認識word2vec的訓練基本概念
系列文
30天初探tensorflow之旅30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言