iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 27
1
Google Developers Machine Learning

Towards Tensorflow 2.0系列 第 27

[Day-27] 生成對抗網路 (GAN) 實作 Part II

今天我們來實際來跑簡單的Dataset,就是 DL 101 資料集 - MNIST。透過較為簡單的Dataset 來理解像GAN這種相對難的演算法,應該能較容易理解GAN!

首先,可以直接透過 tf 的 api 來 load MNIST 資料集:

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

這邊的部分,我們就先拿train 的資料集,畢竟是Unsupervised。

接下來可以做簡單的前處理,並轉成tf 的格式:

x_train = x_train.astype(np.float32) / 255.
train_data = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size*6).batch(batch_size).repeat()
train_data_iter = iter(train_data)

資料處理完後,我們可以先來定義 Generator 以及 Discriminiator 。這部分的話看可以看前一天的部分,前一天有完整的 Model define Generator 以及 Discriminitator。這邊的話,就簡單測試一下Model 產生的 樣態

g = Generator()
d = Discriminator()

noise = tf.random.normal([1, 100])
generated_image = g(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

以及

decision = d(generated_image)
print (decision)

https://ithelp.ithome.com.tw/upload/images/20191012/201199717iMlsZHzlV.png

可以從中簡單看出 Generator Gen出來的圖片,以及 當 Discriminator 判斷出來的 (未train)

接下來就可以簡單的做Model summary了

generator = Generator()
generator.build(input_shape=(batch_size, z_dim))
generator.summary()
discriminator = Discriminator()
discriminator.build(input_shape=(batch_size, 28, 28, 1))
discriminator.summary()

https://ithelp.ithome.com.tw/upload/images/20191012/20119971u2mxau9rwK.png

Model summary後,我們要來定義GAN的Loss,而 Generator 跟 Discriminator 的 loss是不太一樣的。

首先我們先從 Discriminator 開始:

Discriminator的部分,主要概念就是比較 Generator 所產生的假圖片跟真圖片來做比較!

def dis_loss(generator, discriminator, input_noise, real_image, is_trainig):
    fake_image = generator(input_noise, is_trainig)
    d_real_logits = discriminator(real_image, is_trainig)
    d_fake_logits = discriminator(fake_image, is_trainig)

    d_loss_real = loss_real(d_real_logits)
    d_loss_fake = loss_fake(d_fake_logits)
    loss = d_loss_real + d_loss_fake
    return loss

Generator的部分,就是透過 noise 產生圖片後,想辦法去騙過 Discriminator 的結果 (fake_loss)。

def gen_loss(generator, discriminator, input_noise, is_trainig):
    fake_image = generator(input_noise, is_trainig)
    fake_loss = discriminator(fake_image, is_trainig)
    loss = loss_real(fake_loss)
    return loss

而其中 loss_real 以及 loss_fake 就是透過 Discriminator output 的loss去計算 tf.nn.sigmoid_cross_entropy。 比較有差距的就是 labels 的部分,一個是for 真實圖片的loss (label 為1) 一個是for 假圖片的loss (label為0)

def loss_real(logits):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,                                                                 labels=tf.ones_like(logits)))

def loss_fake(logits):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,                                                                  labels=tf.zeros_like(logits)))

定義完loss後,我們可以來寫train的period,記得 Generator 跟 Discriminator 的 tf.GradientTape() 要分開寫!然後weight 也是分開 update。

for epoch in range(epochs):

    batch_x = next(train_data_iter)
    batch_x = tf.reshape(batch_x, shape=inputs_shape)
    batch_x = batch_x * 2.0 - 1.0
    batch_z = tf.random.normal(shape=[batch_size, z_dim])

    with tf.GradientTape() as tape:
        d_loss = dis_loss(generator, discriminator, batch_z, batch_x, is_training)
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    with tf.GradientTape() as tape:
        g_loss = gen_loss(generator, discriminator, batch_z, is_training)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    
    if epoch % 100 == 0:
        print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))

最後我們就來隨機 gen一張圖

https://ithelp.ithome.com.tw/upload/images/20191012/201199711yPPhtuQmL.png

是不是有點看起來像是手寫數字了 XD

小結 :

GAN真的是一個蠻好玩的演算法,有興趣大家可以在網路上找資源!其時候多人已經最好很棒的transfer learning for GAN 。 Ex: 圖像風格轉變 等等。謝謝大家今天漫長的閱讀 ~ 明天是最後一天連假,祝福大家明天最後一天連假愉快!

一天一梗圖:

https://ithelp.ithome.com.tw/upload/images/20191012/20119971nIbQ4jTcBr.png

source

Reference:

GAN_MNIST_Colab


上一篇
[Day-26] 生成對抗網路 (GAN) 實作 Part I
下一篇
[Day-28] 增強式學習 (Reinforcement learning) 介紹
系列文
Towards Tensorflow 2.030
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言