12th iThome Ironman Contest, Day 30

Today we implement a GAN on the MNIST dataset. Because a GAN is genuinely a hard algorithm, I wanted to demonstrate it with relatively simple data. Today's implementation references https://colab.research.google.com/drive/1VafNpL9Tfh360769H4_9R7nEmNOsk78-#scrollTo=SVbk8Ya0EBRC

Import the MNIST dataset through the tf API and convert it into a tf.data pipeline:

import numpy as np
import tensorflow as tf
from tensorflow import keras

batch_size = 256
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32) / 255.  # scale pixels to [0, 1]
train_data = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size * 4).batch(batch_size).repeat()
train_data_iter = iter(train_data)  # was iter(train_data_iter), a NameError
inputs_shape = [-1, 28, 28, 1]
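As a quick sanity check on the pipeline above, the same shuffle/batch/repeat chain can be run on a small synthetic array (the sizes here are illustrative stand-ins, not the training values):

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in for x_train: 10 fake 28x28 "images".
fake_images = np.random.rand(10, 28, 28).astype(np.float32)

batch = 4
ds = tf.data.Dataset.from_tensor_slices(fake_images).shuffle(batch * 4).batch(batch).repeat()
it = iter(ds)

# Because of repeat(), the iterator never exhausts; each next() yields a batch.
first = next(it)
print(first.shape)  # (4, 28, 28)
```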

Since the Generator and the Discriminator were already defined yesterday, we can run a quick test of what they produce here.

import matplotlib.pyplot as plt

# g is the (untrained) Generator instance; feed it a single 100-dim noise vector.
noise = tf.random.normal([1, 100])
generated_image = g(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

https://ithelp.ithome.com.tw/upload/images/20201010/20130246dYpAoJ3RJ1.png
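Since yesterday's model definitions are not repeated in this post, here is a minimal sketch of what DCGAN-style Generator and Discriminator classes could look like. This is an assumed architecture for reference only; the actual layer sizes from the previous post may differ:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Generator(keras.Model):
    def __init__(self):
        super().__init__()
        # Project 100-dim noise up to a 7x7 feature map, then upsample to 28x28.
        self.fc = layers.Dense(7 * 7 * 64)
        self.conv1 = layers.Conv2DTranspose(32, 5, strides=2, padding='same', activation='relu')
        self.conv2 = layers.Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh')

    def call(self, x, training=None):
        x = tf.nn.relu(self.fc(x))
        x = tf.reshape(x, [-1, 7, 7, 64])
        x = self.conv1(x)
        return self.conv2(x)  # (batch, 28, 28, 1), values in [-1, 1]

class Discriminator(keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = layers.Conv2D(32, 5, strides=2, padding='same')
        self.conv2 = layers.Conv2D(64, 5, strides=2, padding='same')
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)  # raw logit, no sigmoid

    def call(self, x, training=None):
        x = tf.nn.leaky_relu(self.conv1(x))
        x = tf.nn.leaky_relu(self.conv2(x))
        return self.fc(self.flatten(x))

g = Generator()
d = Discriminator()
fake = g(tf.random.normal([2, 100]))  # two 100-dim noise vectors
logits = d(fake)
```

Note that the Discriminator outputs raw logits: the sigmoid is folded into the loss functions below for numerical stability.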

Building the networks

z_dim = 100           # noise dimension, matching the test above
learning_rate = 2e-4  # assumed value; a common choice for DCGANs

generator = Generator()
generator.build(input_shape=(batch_size, z_dim))
generator.summary()
discriminator = Discriminator()
discriminator.build(input_shape=(batch_size, 28, 28, 1))
discriminator.summary()
d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

https://ithelp.ithome.com.tw/upload/images/20201010/20130246GPl8rlODOa.png
Now we use the Discriminator to score the fake images produced by the Generator against the real ones.

def dis_loss(generator, discriminator, input_noise, real_image, is_training):
    fake_image = generator(input_noise, training=is_training)
    d_real_logits = discriminator(real_image, training=is_training)
    d_fake_logits = discriminator(fake_image, training=is_training)

    # The Discriminator should classify real images as real and fakes as fake.
    d_loss_real = loss_real(d_real_logits)
    d_loss_fake = loss_fake(d_fake_logits)
    loss = d_loss_real + d_loss_fake
    return loss

Now it is the Generator's turn: it produces images from noise and tries to fool the Discriminator.

def gen_loss(generator, discriminator, input_noise, is_training):
    fake_image = generator(input_noise, training=is_training)
    fake_logits = discriminator(fake_image, training=is_training)
    # The Generator wants the Discriminator to label its fakes as real.
    loss = loss_real(fake_logits)
    return loss

loss_real and loss_fake compute sigmoid cross-entropy on the Discriminator's logits:

def loss_real(logits):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.ones_like(logits)))

def loss_fake(logits):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.zeros_like(logits)))
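To make the loss concrete: for a label of 1 (real) the sigmoid cross-entropy reduces to -log(sigmoid(logit)), and for a label of 0 (fake) to -log(1 - sigmoid(logit)). A quick NumPy check of that identity, using an illustrative logit value:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(logits, labels):
    # Numerically stable form used by tf.nn.sigmoid_cross_entropy_with_logits:
    # max(x, 0) - x*z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

logit = 2.0  # a confident "real" score
loss_if_real = bce_with_logits(logit, 1.0)  # small: prediction agrees with label
loss_if_fake = bce_with_logits(logit, 0.0)  # large: prediction contradicts label

print(loss_if_real, -np.log(sigmoid(logit)))
print(loss_if_fake, -np.log(1.0 - sigmoid(logit)))
```

So a Discriminator that assigns high logits to real images drives loss_real down, while loss_fake pushes it to assign low logits to generated images.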

With the losses in place, we can start training the Generator and the Discriminator.

epochs = 3000       # number of training steps; adjust as needed
is_training = True

for epoch in range(epochs):
    batch_x = next(train_data_iter)
    batch_x = tf.reshape(batch_x, shape=inputs_shape)
    batch_x = batch_x * 2.0 - 1.0  # rescale [0, 1] -> [-1, 1] to match the tanh output
    batch_z = tf.random.normal(shape=[batch_size, z_dim])

    # Train the Discriminator.
    with tf.GradientTape() as tape:
        d_loss = dis_loss(generator, discriminator, batch_z, batch_x, is_training)
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    # Train the Generator.
    with tf.GradientTape() as tape:
        g_loss = gen_loss(generator, discriminator, batch_z, is_training)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

    if epoch % 100 == 0:
        print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))

https://ithelp.ithome.com.tw/upload/images/20201010/20130246D5eWnOqy4x.png
Randomly generate an image:

noise = tf.random.normal([1, z_dim])  # was [1, 50], which does not match the Generator's input size
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
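Since the training batches were rescaled into [-1, 1] (the batch_x * 2.0 - 1.0 step), the Generator's tanh output lives in the same range; mapping it back to [0, 1] before plotting avoids relying on matplotlib's auto-normalization. A small NumPy sketch of that mapping, with a stand-in array:

```python
import numpy as np

# Stand-in for a generated image: tanh keeps values in [-1, 1].
fake = np.tanh(np.random.randn(28, 28)).astype(np.float32)

# Map [-1, 1] back to [0, 1] for display or saving.
display = (fake + 1.0) / 2.0

print(display.min(), display.max())
```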

https://ithelp.ithome.com.tw/upload/images/20201010/20130246JQjunpGmjk.png

Conclusion

It is finally over. I am grateful to live in an era where information is this accessible. Although my grasp of TensorFlow 2.0 this round is still patchy, the series gave me a basic framework and intuition for AI & ML. Finally, thanks to all the experts whose work I referenced over these 30 days, and to my teammates.

