iT邦幫忙

The 12th iThome Ironman Contest (iThome 鐵人賽)

DAY 5

How to Implement a GAN

Today we are going to actually implement a GAN!
The code is based on simple_keras_GAN.

Prerequisites

We use Colab as the platform for this implementation and build the model with Keras.

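Dataset

The MNIST handwritten digit dataset, which is built into Keras.
The training set contains 60,000 grayscale images of 28x28 pixels, and the test set contains 10,000 images in the same format. There are 10 digit classes in total, labeled 0 to 9.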

Code

Preprocessing

After loading the dataset, scale all pixel values from [0, 255] to the range -1 to 1.

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, LeakyReLU
from keras.models import Model

# Load MNIST. Only the training images are used; the labels and test set are not needed for this GAN.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [-1, 1], matching the tanh output of the generator.
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
# Add a trailing channel axis: (60000, 28, 28) -> (60000, 28, 28, 1).
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], x_train.shape[2], 1))
width = 28
height = 28
channels = 1
shape = (width, height, channels)
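
The scaling step above can be checked on a handful of pixel values. The following is a minimal sketch using only NumPy; the helper names to_model_range and to_pixel_range are made up for illustration and are not part of simple_keras_GAN. The inverse mapping is what you would apply if you wanted to turn generator output back into ordinary 0~255 pixels.

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1] (same formula as above)."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_pixel_range(scaled):
    """Inverse mapping: floats in [-1, 1] back to uint8 pixels in [0, 255]."""
    return np.clip(np.rint(scaled * 127.5 + 127.5), 0, 255).astype(np.uint8)

# A few representative pixel values, including both ends of the range.
pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)

print(scaled.min(), scaled.max())
print(restored)
```

Pixel value 0 maps to -1 and 255 maps to 1, which is the same range a tanh output layer can produce.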

generator

def generator():
  # Input: a 100-dimensional noise vector.
  noise_input = Input((100,))

  # Three Dense blocks, each followed by LeakyReLU and batch normalization.
  layer = Dense(16)(noise_input)
  layer = LeakyReLU(alpha=0.2)(layer)
  layer = BatchNormalization(momentum=0.8)(layer)

  layer = Dense(32)(layer)
  layer = LeakyReLU(alpha=0.2)(layer)
  layer = BatchNormalization(momentum=0.8)(layer)

  layer = Dense(256)(layer)
  layer = LeakyReLU(alpha=0.2)(layer)
  layer = BatchNormalization(momentum=0.8)(layer)

  # tanh keeps the output in [-1, 1], the same range as the preprocessed images.
  output = Dense(width * height * channels, activation='tanh')(layer)
  output = Reshape((width, height, channels))(output)
  model = Model(inputs=[noise_input], outputs=[output])
  # `losses` and `optimizer` are globals that must be defined before generator() is called.
  model.compile(loss=losses.binary_crossentropy, optimizer=optimizer)
  return model
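
To make the shape flow of the generator easier to follow (100 -> 16 -> 32 -> 256 -> 784 -> 28x28x1), here is a rough NumPy-only sketch of the forward pass. It is an illustration under simplifying assumptions, not the Keras model itself: the weights are random, biases and batch normalization are left out, and generator_forward is a name invented for this example.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: pass positive values through, scale negative values by alpha."""
    return np.where(x > 0, x, alpha * x)

def generator_forward(noise, rng):
    """Push noise through Dense-like layers with random weights (no bias, no batch norm)."""
    sizes = [100, 16, 32, 256, 28 * 28 * 1]
    h = noise
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        w = rng.normal(0.0, 0.05, size=(n_in, n_out))
        h = h @ w
        is_last = i == len(sizes) - 2
        # Hidden layers use LeakyReLU; the last layer uses tanh so outputs stay in [-1, 1].
        h = np.tanh(h) if is_last else leaky_relu(h)
    return h.reshape(-1, 28, 28, 1)

rng = np.random.default_rng(0)
images = generator_forward(rng.normal(0, 1, size=(4, 100)), rng)
print(images.shape)  # (4, 28, 28, 1)
```

Four noise vectors go in and four 28x28x1 "images" come out. With untrained weights they are just noise, but the shapes and value range already match the real data.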

discriminator

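def discriminator():
  # Input: one 28x28x1 image, either real or generated.
  image_input = Input(shape)
  layer = Flatten()(image_input)
  layer = Dense(width * height * channels)(layer)
  layer = LeakyReLU(alpha=0.2)(layer)
  layer = Dense(width * height * channels // 2)(layer)
  layer = LeakyReLU(alpha=0.2)(layer)

  # Sigmoid output: the estimated probability that the input image is real.
  output = Dense(1, activation='sigmoid')(layer)
  model = Model(inputs=[image_input], outputs=[output])
  # `losses` and `optimizer` are globals that must be defined before discriminator() is called.
  model.compile(loss=losses.binary_crossentropy, optimizer=optimizer, metrics=['accuracy'])
  return model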
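
The discriminator is compiled with binary cross-entropy, so it may help to see what that loss does to a few predictions. This is a small NumPy sketch of the standard formula, written by hand for illustration; it is not the Keras implementation, and the example probabilities are made up.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_sample = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(per_sample))

# Labels: 1 = real image, 0 = generated image.
y_true = np.array([1.0, 1.0, 0.0, 0.0])
# A discriminator that separates real from fake well...
confident = np.array([0.9, 0.8, 0.1, 0.2])
# ...and one that is being fooled by the generator.
fooled = np.array([0.4, 0.3, 0.7, 0.6])

loss_confident = binary_crossentropy(y_true, confident)
loss_fooled = binary_crossentropy(y_true, fooled)
print(loss_confident < loss_fooled)
```

The loss is small when the predicted probabilities agree with the labels and grows as the discriminator gets fooled.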

Connecting the generator and discriminator

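def stacked_generator_discriminator(G, D):
  # Freeze the discriminator inside the stacked model so that
  # training this model only updates the generator's weights.
  D.trainable = False
  noise_input = Input((100,))
  # noise -> generator -> fake image -> discriminator -> probability of "real"
  layer = G(noise_input)
  output = D(layer)
  model = Model(inputs=[noise_input], outputs=[output])
  return model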
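
Why freeze D in the stacked model? The idea is that the generator's loss is computed through the discriminator, but only the generator's parameters get updated. Below is a toy sketch of one such update, assuming a scalar "generator" a*z + c and a scalar "discriminator" sigmoid(w*x + b). These tiny models and the name generator_step are invented for illustration and have nothing to do with the real network sizes.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def generator_step(g_params, d_params, z, lr=0.1):
    """One gradient step on the generator only; the discriminator stays frozen."""
    a, c = g_params
    w, b = d_params                      # frozen: read but never updated
    fake = a * z + c                     # toy generator output
    p = sigmoid(w * fake + b)            # discriminator's belief that fake is real
    loss = float(-np.mean(np.log(p)))    # cross-entropy against the label "real" (1)
    grad_fake = -(1.0 - p) * w           # d(loss)/d(fake), backpropagated through D
    grad_a = np.mean(grad_fake * z)
    grad_c = np.mean(grad_fake)
    return (a - lr * grad_a, c - lr * grad_c), loss

z = np.array([-1.0, 0.5, 2.0])
g_params = (0.5, 0.0)
d_params = (1.5, -0.2)

new_g_params, loss_before = generator_step(g_params, d_params, z)
_, loss_after = generator_step(new_g_params, d_params, z)
print(loss_before, loss_after)
```

After the step the generator fools this fixed discriminator a little better, while d_params is left untouched. That is the role D.trainable = False plays in the stacked Keras model.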

Plotting the images

def plot_images(save2file=False, samples=16, name="/mnist_", step=0):
  ''' Plot generated images in a square grid '''
  import math
  # IMAGE_DIR and G are globals that must exist before this function is called.
  filename = IMAGE_DIR + name + "_%d.png" % step
  noise = np.random.normal(0, 1, (samples, 100))

  images = G.predict(noise)

  plt.figure(figsize=(10, 10))

  # plt.subplot needs integer grid dimensions; math.sqrt returns a float.
  grid = int(math.sqrt(samples))
  for i in range(images.shape[0]):
      plt.subplot(grid, grid, i + 1)
      image = images[i, :, :, :]
      image = np.reshape(image, [height, width])
      plt.imshow(image, cmap='gray')
      plt.axis('off')
  plt.tight_layout()

  if save2file:
      plt.savefig(filename)
      plt.close('all')
  else:
      plt.show()
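
The plotting function assumes that samples is a perfect square such as 16. If you want to plot some other number of images, one possible approach is to compute the grid size separately. The helper below is a hypothetical sketch using only the standard library (math.isqrt needs Python 3.8 or later); grid_shape is not part of the original code.

```python
import math

def grid_shape(samples):
    """Return (rows, cols) of the smallest near-square grid that fits `samples` images."""
    if samples < 1:
        raise ValueError("samples must be at least 1")
    cols = math.isqrt(samples)
    if cols * cols < samples:
        cols += 1                        # round the square root up
    rows = math.ceil(samples / cols)     # just enough rows to hold everything
    return rows, cols

print(grid_shape(16))  # (4, 4)
print(grid_shape(10))  # (3, 4)
```

The resulting rows and cols are plain integers, so they could be passed to plt.subplot(rows, cols, i + 1) directly.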

Training

def train(G, D, model, X_train, epochs=10000, batch=32, save_interval=100):
  # Note: each "epoch" here is a single batch update, not a full pass over the data.
  # WEIGHT_DIR and the os module must be available before this function is called.
  half = batch // 2
  for cnt in range(epochs):
    # Train the discriminator on half a batch of real images and half a batch of fakes.
    random_index = np.random.randint(0, len(X_train) - half)
    legit_images = X_train[random_index : random_index + half].reshape(half, width, height, channels)

    gen_noise = np.random.normal(0, 1, (half, 100))
    synthetic_images = G.predict(gen_noise)

    # Real images are labeled 1, generated images are labeled 0.
    x_combined_batch = np.concatenate((legit_images, synthetic_images))
    y_combined_batch = np.concatenate((np.ones((half, 1)), np.zeros((half, 1))))
    d_loss = D.train_on_batch(x_combined_batch, y_combined_batch)

    # Train the generator through the stacked model: the fakes are deliberately
    # labeled 1 ("real") so the generator learns to fool the discriminator.
    noise = np.random.normal(0, 1, (batch, 100))
    y_mislabeled = np.ones((batch, 1))

    g_loss = model.train_on_batch(noise, y_mislabeled)
    print('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))

    if cnt % save_interval == 0:
      plot_images(save2file=True, step=cnt)

  # Save both models once training is finished.
  tag = 'GAN0829_G{}'.format(0)
  h5_weight_path = os.path.join(WEIGHT_DIR, tag + '.h5')
  G.save(h5_weight_path)
  tag = 'GAN0829_D{}'.format(0)
  h5_weight_path = os.path.join(WEIGHT_DIR, tag + '.h5')
  D.save(h5_weight_path)
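
The discriminator batch in train() is the part that is easiest to get wrong, so here it is in isolation as a NumPy-only sketch with random arrays standing in for MNIST and for the generator output. One deliberate difference: train() takes a contiguous slice at a random offset, while this sketch picks indices independently at random, which is a common alternative. make_discriminator_batch is a name invented for this example.

```python
import numpy as np

def make_discriminator_batch(x_real, fake_images, rng):
    """Combine randomly chosen real images with fakes; label real as 1 and fake as 0."""
    half = len(fake_images)
    # Sample `half` distinct real images instead of a contiguous slice.
    idx = rng.choice(len(x_real), size=half, replace=False)
    x = np.concatenate((x_real[idx], fake_images))
    y = np.concatenate((np.ones((half, 1)), np.zeros((half, 1))))
    return x, y

rng = np.random.default_rng(0)
# Stand-ins: 100 "real" images and 4 "generated" images, all scaled to [-1, 1].
x_real = rng.uniform(-1, 1, size=(100, 28, 28, 1)).astype(np.float32)
fake_images = rng.uniform(-1, 1, size=(4, 28, 28, 1)).astype(np.float32)

x_batch, y_batch = make_discriminator_batch(x_real, fake_images, rng)
print(x_batch.shape, y_batch.shape)  # (8, 28, 28, 1) (8, 1)
```

The first half of the batch is real and the second half is fake, which matches the label layout that D.train_on_batch receives in train().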

Results

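import os
from keras import losses

# Assumption: the original post never defines IMAGE_DIR or WEIGHT_DIR; these paths are placeholders.
IMAGE_DIR = './images'
WEIGHT_DIR = './weights'
os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(WEIGHT_DIR, exist_ok=True)

optimizer = "adadelta"
G = generator()
D = discriminator()

# generator() and discriminator() already compile their models; these calls repeat the same settings.
G.compile(loss='binary_crossentropy', optimizer=optimizer)
D.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
model = stacked_generator_discriminator(G, D)
model.compile(loss=losses.binary_crossentropy, optimizer=optimizer)

train(G, D, model, x_train, epochs=20000, batch=128, save_interval=100)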

Conclusion

Today we implemented a GAN, and the results turned out fairly well!

References

simple_keras_GAN
[實戰系列] 使用 Keras 搭建一個 GAN 魔法陣(模型) ([Hands-on series] Building a GAN "magic circle" (model) with Keras)

