Today we'll implement a GAN on the MNIST dataset. GANs are genuinely difficult to train, so a simple dataset makes the results easier to see. This implementation follows https://colab.research.google.com/drive/1VafNpL9Tfh360769H4_9R7nEmNOsk78-#scrollTo=SVbk8Ya0EBRC
```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

batch_size, z_dim = 256, 100
learning_rate, epochs, is_training = 2e-4, 3000, True  # assumed values, not given in the original snippet

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32) / 255.  # scale pixels to [0, 1]
train_data = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size * 4).batch(batch_size).repeat()
train_data_iter = iter(train_data)  # fix: was iter(train_data_iter), which is undefined at this point
inputs_shape = [-1, 28, 28, 1]
```
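Before going further, it's worth pulling one batch to confirm the pipeline's shapes (a quick check of my own, not part of the original notebook):

```python
# Each batch should be (256, 28, 28), float32, with values in [0, 1].
sample = next(iter(train_data))
print(sample.shape, sample.dtype, float(tf.reduce_max(sample)))
```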
Since the Generator and Discriminator were already defined yesterday, here we can simply instantiate them and do a quick test of what they produce. (A minimal sketch of those two classes follows for reference.)
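The class definitions aren't shown in this post, so here is a DCGAN-style sketch of what they might look like. This is my assumption for a runnable reference, not the exact code from yesterday's post:

```python
class Generator(keras.Model):
    # Maps a z_dim noise vector to a 28x28x1 image in [-1, 1].
    def __init__(self):
        super().__init__()
        self.fc = keras.layers.Dense(7 * 7 * 128)
        self.bn = keras.layers.BatchNormalization()
        self.conv1 = keras.layers.Conv2DTranspose(64, 4, 2, 'same')  # 7x7 -> 14x14
        self.conv2 = keras.layers.Conv2DTranspose(1, 4, 2, 'same')   # 14x14 -> 28x28

    def call(self, inputs, training=None):
        x = tf.reshape(self.fc(inputs), [-1, 7, 7, 128])
        x = tf.nn.leaky_relu(self.bn(self.conv1(x), training=training))
        return tf.tanh(self.conv2(x))


class Discriminator(keras.Model):
    # Maps a 28x28x1 image to a single real/fake logit.
    def __init__(self):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 4, 2, 'same')
        self.conv2 = keras.layers.Conv2D(128, 4, 2, 'same')
        self.flatten = keras.layers.Flatten()
        self.fc = keras.layers.Dense(1)  # raw logit; the loss applies the sigmoid

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.conv2(x))
        return self.fc(self.flatten(x))
```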
```python
import matplotlib.pyplot as plt

generator = Generator()
generator.build(input_shape=(batch_size, z_dim))
generator.summary()

discriminator = Discriminator()
discriminator.build(input_shape=(batch_size, 28, 28, 1))
discriminator.summary()

# Sanity check: an untrained generator should just output 28x28 noise.
noise = tf.random.normal([1, z_dim])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

# Adam with beta_1 = 0.5, a common choice for GAN training.
d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
```
Now the Discriminator compares real images against the fake images produced by the Generator; its loss pushes real logits toward 1 and fake logits toward 0.
```python
def dis_loss(generator, discriminator, input_noise, real_image, is_training):
    fake_image = generator(input_noise, training=is_training)
    d_real_logits = discriminator(real_image, training=is_training)
    d_fake_logits = discriminator(fake_image, training=is_training)
    d_loss_real = loss_real(d_real_logits)  # real images should be classified as real
    d_loss_fake = loss_fake(d_fake_logits)  # fake images should be classified as fake
    loss = d_loss_real + d_loss_fake
    return loss
```
Next it's the Generator's turn: it produces an image from noise and then tries to fool the Discriminator into labeling that fake as real.
```python
def gen_loss(generator, discriminator, input_noise, is_training):
    fake_image = generator(input_noise, training=is_training)
    d_fake_logits = discriminator(fake_image, training=is_training)
    # The generator wants its fakes labeled as real,
    # so it reuses loss_real (all-ones targets) on the fake logits.
    loss = loss_real(d_fake_logits)
    return loss
```
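Formally, this is the standard non-saturating GAN objective: instead of minimizing log(1 − D(G(z))), the generator maximizes log D(G(z)), which is exactly what reusing loss_real on the fake logits does:

$$L_D = -\mathbb{E}_{x}\big[\log D(x)\big] - \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big], \qquad L_G = -\mathbb{E}_{z}\big[\log D(G(z))\big]$$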
loss_real and loss_fake are just sigmoid cross-entropy on the Discriminator's logits, with all-ones and all-zeros targets respectively:
```python
def loss_real(logits):
    # Cross-entropy against all-ones labels: "these should look real".
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=tf.ones_like(logits)))

def loss_fake(logits):
    # Cross-entropy against all-zeros labels: "these should look fake".
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=tf.zeros_like(logits)))
```
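As a quick sanity check (my addition), both losses can be evaluated on a single batch before training. With an untrained discriminator the logits are near zero, so each cross-entropy term sits roughly around ln 2 ≈ 0.69, putting d loss near 1.4:

```python
test_z = tf.random.normal([batch_size, z_dim])
test_x = tf.reshape(next(train_data_iter), inputs_shape) * 2.0 - 1.0
print('d:', float(dis_loss(generator, discriminator, test_z, test_x, False)))
print('g:', float(gen_loss(generator, discriminator, test_z, False)))
```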
With the losses defined, we can train the Generator and Discriminator. Note that because the dataset repeats forever, each "epoch" below is really a single batch step.
```python
for epoch in range(epochs):
    batch_x = next(train_data_iter)
    batch_x = tf.reshape(batch_x, shape=inputs_shape)
    batch_x = batch_x * 2.0 - 1.0  # rescale [0, 1] -> [-1, 1] to match the tanh output
    batch_z = tf.random.normal(shape=[batch_size, z_dim])

    # Update the discriminator on real + fake images.
    with tf.GradientTape() as tape:
        d_loss = dis_loss(generator, discriminator, batch_z, batch_x, is_training)
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    # Update the generator to fool the just-updated discriminator.
    with tf.GradientTape() as tape:
        g_loss = gen_loss(generator, discriminator, batch_z, is_training)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

    if epoch % 100 == 0:
        print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))
```
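To watch training progress visually, it also helps to save a grid of generated samples every few hundred steps. A small helper sketch (my addition, using only matplotlib; call save_samples(generator, epoch) inside the `if epoch % 100 == 0` block):

```python
def save_samples(generator, step, n=16):
    # Generate n images and save them as a 4x4 grid.
    z = tf.random.normal([n, z_dim])
    imgs = ((generator(z, training=False) + 1.0) / 2.0).numpy()  # back to [0, 1]
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for img, ax in zip(imgs, axes.flat):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig('gan_samples_%05d.png' % step)
    plt.close(fig)
```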
Finally, generate a random image:
```python
noise = tf.random.normal([1, z_dim])  # fix: the noise size must match z_dim (the original [1, 50] would fail)
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
```
And that's a wrap. I'm grateful to live in an era where information is so accessible. Even though my TensorFlow 2.0 learning this time was a bit scattered, it gave me a basic mental framework for AI & ML. Finally, thanks to the experts whose work I referenced over these 30 days, and to my teammates.