Google Developers Machine Learning

## [Day-26] Generative Adversarial Network (GAN) Implementation, Part I

#### Discriminator Network:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Conv2D(filters, kernel_size, strides, padding)
        self.conv_1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv_2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn_1 = layers.BatchNormalization()
        self.conv_3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn_2 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc_layer = layers.Dense(1)  # single logit: real vs. fake

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv_1(inputs))
        x = tf.nn.leaky_relu(self.bn_1(self.conv_2(x), training=training))
        x = tf.nn.leaky_relu(self.bn_2(self.conv_3(x), training=training))
        x = self.flatten(x)
        x = self.fc_layer(x)
        return x
```
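Why does a single `Dense(1)` after `Flatten` suffice here? With `'valid'` padding, each convolution maps spatial size n to ⌊(n − k)/s⌋ + 1, so three kernel-5, stride-3 convolutions shrink a 64×64 input down to a single spatial position. A plain-Python sketch of that arithmetic (assuming 64×64 inputs, as in the sanity check further down):

```python
def conv_out(size, kernel, stride):
    """Spatial output size of a 'valid'-padded convolution."""
    return (size - kernel) // stride + 1

size = 64  # assumed input resolution
for kernel, stride in [(5, 3), (5, 3), (5, 3)]:  # conv_1, conv_2, conv_3
    size = conv_out(size, kernel, stride)
    print(size)  # 20, then 6, then 1
```

So `Flatten` sees a `[b, 1, 1, 256]` tensor, and the final logit is computed from 256 features.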

#### Generator Network:

The Generator is essentially an image producer: it takes a low-dimensional matrix (the latent vector) and upsamples it back into a full-size image. It uses `layers.Conv2DTranspose` (transposed convolution, often loosely called deconvolution), which, simply put, reconstructs an image from feature maps, reversing a convolution's downsampling. The overall pipeline:

Input -> Dense -> Conv Transpose -> BN -> .. -> Tanh

```python
class Generator(keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        # decoder: latent vector -> 3x3x512 feature map -> image
        self.fc_layer_1 = layers.Dense(3 * 3 * 512)
        self.conv_1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn_1 = layers.BatchNormalization()
        self.conv_2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn_2 = layers.BatchNormalization()
        self.conv_3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        x = self.fc_layer_1(inputs)               # [b, 100] -> [b, 4608]
        x = tf.reshape(x, [-1, 3, 3, 512])        # [b, 3, 3, 512]
        x = tf.nn.leaky_relu(x)
        x = self.bn_1(self.conv_1(x), training=training)  # [b, 9, 9, 256]
        x = self.bn_2(self.conv_2(x), training=training)  # [b, 21, 21, 128]
        x = self.conv_3(x)                        # [b, 64, 64, 3]
        x = tf.tanh(x)                            # pixel values in [-1, 1]
        return x
```
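Going the other way, a `'valid'` transposed convolution maps spatial size n to (n − 1)·s + k. Tracing the three `Conv2DTranspose` layers above shows how the 3×3 feature map from the `Dense` + reshape step grows back to a 64×64 image (plain Python, no TensorFlow required):

```python
def deconv_out(size, kernel, stride):
    """Spatial output size of a 'valid'-padded transposed convolution."""
    return (size - 1) * stride + kernel

size = 3  # after Dense(3*3*512) and reshape to [-1, 3, 3, 512]
for kernel, stride in [(3, 3), (5, 2), (4, 3)]:  # conv_1, conv_2, conv_3
    size = deconv_out(size, kernel, stride)
    print(size)  # 9, then 21, then 64
```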

```python
# quick sanity check of both networks
g = Generator()
d = Discriminator()

x = tf.random.normal([1, 64, 64, 3])   # a dummy image batch
z = tf.random.normal([1, 100])         # a latent vector

fake_image = g(z)          # generator: latent -> image
print(fake_image.shape)    # (1, 64, 64, 3)
out = d(x)                 # discriminator: image -> logit
print(out.shape)           # (1, 1)
```
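Note that the discriminator ends in a plain `Dense(1)` with no sigmoid, so `out` is a raw logit rather than a probability. In the training step this pairs naturally with a binary cross-entropy loss computed from logits (e.g. `tf.nn.sigmoid_cross_entropy_with_logits`); the numerically stable formula it implements, max(l, 0) − l·y + log(1 + e^(−|l|)), can be sketched in plain Python (the example values below are illustrative, not from the original):

```python
import math

def bce_from_logits(logit, label):
    """Stable binary cross-entropy on a raw logit; same formula as
    tf.nn.sigmoid_cross_entropy_with_logits."""
    return max(logit, 0) - logit * label + math.log(1 + math.exp(-abs(logit)))

# A confidently "real" logit scored against a real label: low loss.
print(round(bce_from_logits(5.0, 1.0), 4))  # 0.0067
# The same logit scored against a fake label: high loss.
print(round(bce_from_logits(5.0, 0.0), 4))  # 5.0067
```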
