在經過前二天的介紹,相信大家對 GAN 已經有一個概念了。今天就來從程式碼學習看看吧,先從閱讀前輩的程式碼開始。一樣是使用 colab 的環境,今天要研讀的是 Rajaswa Patil 所分享的程式 [Training GANs using Google Colaboratory]。這是用 torch 的架構,用的資料集是 mnist。也就是說試著用 GAN 去訓練產生手寫數字的模型。另外也有找到一篇是用 tensorflow 寫的,用的資料集是 fashion-mnist,這二篇試著跑過一次,都成功完成,看今天如果來不及介紹就明天再介紹 tensorflow 寫的囉。
先來介紹第一篇,原始的程式碼請看作者文章的連結,作者文章也有詳細的介紹,我們就記錄一下比較重點的地方,並且我會簡化一些參數,所以不要直接拿這篇的程式碼去跑唷。
定義資料集,在這裡使用的數據集是 MNIST 手寫數字圖像數據集。創建Dataset類,用 datasets.MNIST 下載 MNIST 手寫數字圖像數據集並初始化數據加載器,加載器會將數據送到模型中。
def mnist_data():
compose = transforms.Compose(
[transforms.ToTensor(),
# transforms.Normalize((.5, .5, .5), (.5, .5, .5))
transforms.Normalize([0.5], [0.5])
])
out_dir = './dataset'
return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)
data = mnist_data()
在 compose 這段程式碼 colab 執行時會有錯誤,要改一下程式碼。
再來定義鑑別網路 DiscriminatorNet,他定義了三層的隱藏層。
self.hidden0 = nn.Sequential(
nn.Linear(n_features,1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
self.hidden1 = nn.Sequential(
nn.Linear(1024,512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
self.hidden2 = nn.Sequential(
nn.Linear(512,256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
在生產網路 GeneratorNet,一樣也定義三層隱藏層。
self.hidden0 = nn.Sequential(
nn.Linear(n_features,256),
nn.LeakyReLU(0.2)
)
self.hidden1 = nn.Sequential(
nn.Linear(256,512),
nn.LeakyReLU(0.2)
)
self.hidden2 = nn.Sequential(
nn.Linear(512,1024),
nn.LeakyReLU(0.2)
)
接下來,繼續定義
最後就是全部串起來開始跑囉。
訓練鑑別器。
train_discriminator(D優化器, 真樣本, 假樣本)
訓練生成器。
train_generator(G優化器, 假樣本)
最後就是一些訓練過程的截圖,可以看到從一開始的雜訊圖樣慢慢呈現出手寫數字的樣子。
一開始的雜訊圖樣。
epoch:4。雜訊稍微減少一點。
epoch:20。看起來有點像數字,但很殘破。
epoch:143。已經很像手寫數字了。
以上就是 Rajaswa Patil 所分享程式的介紹,詳細的程式碼請參考作者的原稿唷。
好,第25天,結束。
參考
Training GANs using Google Colaboratory
ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION
機器學習ML NOTE SGD,Momentum,AdaGrad,Adam Optimizer