iT邦幫忙

2025 iThome 鐵人賽

DAY 25
0

昨天我們成功實作了 DCGAN 並用於生成手寫數字,但問題是「我們沒辦法指定 GAN 具體會生成什麼內容」。因此今天將要學習,如何進行可控的生成。

cGAN

條件生成對抗網路 (conditional GAN, cGAN) 的想法非常直觀:在訓練和生成時,為生成器和判別器,都額外提供一個「條件」資訊 y。這個 y 可以是任何想要的控制變因,如圖像標籤、一段文字,或是另一張圖片。

cGAN 的結構也有所改動

  • 生成器:輸入不再只是隨機噪點 z,而是噪點 z 和條件 y 的拼接。生成器現在的任務是,學習生成一個既真實、又符合條件 y 的圖像 G(z|y)。

  • 判別器:輸入不再只是一張圖片 x,而是圖片 x 和條件 y 的拼接。判別器現在的任務是,判斷這張圖片 x 是否不僅真實,而且還與給定的條件 y 相匹配。

判別器除了懲罰很假的圖片,也會懲罰雖然看起來真實,但類別錯誤的圖片,達到生成對應內容的效果。

StyleGAN

2018 年由輝達推出的 StyleGAN 讓我們可以對生成結果的各個層次的「風格」 進行獨立、精細的控制,並取得了空前的成功。核心創新點包含

  1. 風格注入 (style injection)

    • 映射網路 (mapping network):StyleGAN 不再像傳統 GAN 那樣,直接將噪點 z 餵給生成器。它首先用一個 8 層的 ANN(映射網路 f),將輸入的潛在編碼 z,轉換成一個中間的潛在編碼 w。作者認為,這個 w 所在的潛在空間,比 z 空間能更好地分離不同的特徵。

    • 自適應實例正規化 (AdaIN):w 並不是只在生成器的開頭輸入一次,而是會被轉換成「風格 (style)」向量,並透過 AdaIN 這個操作,在生成器的每一層都被注入進去。這使得 w 能夠控制從低階特徵(如細微的皮膚紋理、髮絲)到高階特徵(如臉型、髮型、身份)的方方面面。

  2. 漸進式增長 (Progressive Growing):為了生成高解析度圖片(如 1024×1024),StyleGAN 借鑒了 ProGAN 的思想。它在訓練時,先從一個極低解析度(如 4×4)的生成器和判別器開始訓練,待其穩定後,再逐漸地、平滑地增加新的層,將解析度提升到 8×8, 16×16, ...,直到目標的 1024×1024。這種由易到難的訓練方式,極大地增強了訓練的穩定性和最終的生成品質。

  3. 隨機噪點輸入:為了生成更逼真的、隨機性的細節(例如雀斑、毛孔、頭髮的精確位置),StyleGAN 在生成器的每一層,還會額外輸入一些隨機的噪點圖。

使用 cGAN 生成手寫數字

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# --- 1. 設定超參數 (與 Day 24 類似,但增加了 num_classes) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_size = 100
hidden_size = 256
image_size = 784
num_classes = 10 # 0-9 共 10 個類別
num_epochs = 50
batch_size = 128
learning_rate = 0.0002

# --- 2. 載入 MNIST 數據集 ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# --- 3. 修改生成器與判別器 ---
# 生成器 G
class Generator(nn.Module):
    def __init__(self, latent_size, hidden_size, image_size, num_classes):
        super(Generator, self).__init__()
        # 輸入維度是噪點 + 類別嵌入
        self.main = nn.Sequential(
            nn.Linear(latent_size + num_classes, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, image_size),
            nn.Tanh()
        )
    def forward(self, z, labels):
        # 將噪點 z 和 one-hot 編碼的 labels 拼接
        cgan_input = torch.cat([z, labels], 1)
        return self.main(cgan_input)

# 判別器 D
class Discriminator(nn.Module):
    def __init__(self, image_size, hidden_size, num_classes):
        super(Discriminator, self).__init__()
        # 輸入維度是圖片 + 類別嵌入
        self.main = nn.Sequential(
            nn.Linear(image_size + num_classes, hidden_size * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )
    def forward(self, img, labels):
        # 將圖片和 one-hot 編碼的 labels 拼接
        cgan_input = torch.cat([img, labels], 1)
        return self.main(cgan_input)

G = Generator(latent_size, hidden_size, image_size, num_classes).to(device)
D = Discriminator(image_size, hidden_size, num_classes).to(device)

# --- 4. 損失函數與優化器 (與 Day 24 相同) ---
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)

# --- 5. 訓練模型 (主要改動在輸入) ---
print("開始訓練 cGAN...")
if not os.path.exists('cgan_results'): os.makedirs('cgan_results')

def to_one_hot(labels, num_classes):
    """將標籤轉換為 one-hot 編碼"""
    one_hot = torch.zeros(labels.size(0), num_classes).to(device)
    one_hot.scatter_(1, labels.view(-1, 1), 1)
    return one_hot

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)
        images = images.reshape(batch_size, -1).to(device)
        labels_one_hot = to_one_hot(labels, num_classes) # 將標籤轉為 one-hot
        
        real_labels_val = torch.ones(batch_size, 1).to(device)
        fake_labels_val = torch.zeros(batch_size, 1).to(device)
        
        # --- (1) 訓練判別器 ---
        outputs = D(images, labels_one_hot) # 傳入圖片和條件
        d_loss_real = criterion(outputs, real_labels_val)
        
        z = torch.randn(batch_size, latent_size).to(device)
        # 讓生成器也使用同樣的標籤條件
        fake_images = G(z, labels_one_hot)
        outputs = D(fake_images.detach(), labels_one_hot)
        d_loss_fake = criterion(outputs, fake_labels_val)
        
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad(); d_loss.backward(); d_optimizer.step()
        
        # --- (2) 訓練生成器 ---
        outputs = D(fake_images, labels_one_hot)
        g_loss = criterion(outputs, real_labels_val)
        
        g_optimizer.zero_grad(); g_loss.backward(); g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    # --- 6. 生成指定的數字 ---
    with torch.no_grad():
        # 我們要生成 10 個數字 (0-9),每個生成 10 個樣本
        num_samples = 10
        fixed_z = torch.randn(num_samples * num_classes, latent_size).to(device)
        fixed_labels = torch.arange(0, num_classes).repeat(num_samples).to(device)
        fixed_labels_one_hot = to_one_hot(fixed_labels, num_classes)
        
        generated_images = G(fixed_z, fixed_labels_one_hot).view(-1, 1, 28, 28)
        save_image(generated_images, f'cgan_results/epoch_{epoch+1}.png', nrow=num_classes)
print("訓練完成!")

結果
Epoch 10
https://ithelp.ithome.com.tw/upload/images/20250904/20178100UxtpBMGFEm.png
Epoch 20
https://ithelp.ithome.com.tw/upload/images/20250904/20178100o0w9NWo1u4.png
Epoch 30
https://ithelp.ithome.com.tw/upload/images/20250904/20178100i88TFzwnIm.png
Epoch 40
https://ithelp.ithome.com.tw/upload/images/20250904/20178100T8JusOiAmW.png
Epoch 50
https://ithelp.ithome.com.tw/upload/images/20250904/20178100N4NPMNIRpI.png


上一篇
Day 24 - GAN
下一篇
Day 26 - Vision Transformer
系列文
從0開始:傳統圖像處理到深度學習模型26
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言