iT邦幫忙

2024 iThome 鐵人賽

DAY 15
0
AI/ ML & Data

AI Unlocked: 30 Days to AI Brilliance系列 第 15

神秘數字的創造者:用GAN生成手寫數字的旅程

  • 分享至 

  • xImage
  •  

GAN的魅力在於它能夠生成逼真的圖像,而在這篇文章中,我們將探索如何利用GAN來生成手寫數字。

  1. 導入所需的庫

import numpy as np import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras import layers

  1. 加載和準備數據集
    使用MNIST數據集做訓練
    https://ithelp.ithome.com.tw/upload/images/20240929/20169257Yyh4z5vtuz.jpg
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0  # 正規化
x_train = np.expand_dims(x_train, axis=-1)  # 增加通道維度
  1. 定義生成器模型
    生成器會從隨機噪聲中生成手寫數字。
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=100, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model
  1. 定義判別器模型
    判別器會判斷圖像是真實的還是由生成器生成的。
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28, 1)))
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))  # 二元分類
    return model
  1. 構建和編譯GAN模型
    將生成器和判別器連接起來,並編譯GAN模型。
generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# GAN模型
discriminator.trainable = False  # 固定判別器
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
  1. 訓練GAN模型
    訓練過程中,我們將不斷生成圖像並訓練判別器。
def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # 隨機選擇真實圖像
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]

        # 生成假圖像
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise)

        # 訓練判別器
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # 訓練生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

        if epoch % 100 == 0:
            print(f"Epoch: {epoch}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")

train_gan(epochs=10000, batch_size=64)
  1. 生成和顯示手寫數字
    在訓練結束後,生成一些手寫數字並顯示出來。
def generate_and_display_images(num_images):
    noise = np.random.normal(0, 1, (num_images, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5  # 反正規化到0到1之間

    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.show()

generate_and_display_images(10)

這項技術不僅展示了人工智慧的強大能力,也讓我們能夠創造出全新的數位藝術作品。GAN的潛力無窮,它在影像生成、風格轉換等多個領域都有著廣泛的應用。


上一篇
GAN的雙面性:解決訓練不穩定與探索變種模型
下一篇
智慧金融防線:AI 如何運用於信用評分與欺詐檢測的技術突破
系列文
AI Unlocked: 30 Days to AI Brilliance30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言