iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 25
0

在經過前二天的介紹,相信大家對 GAN 已經有一個概念了。今天就來從程式碼學習看看吧,先從閱讀前輩的程式碼開始。一樣是使用 colab 的環境,今天要研讀的是 Rajaswa Patil 所分享的程式 [Training GANs using Google Colaboratory]。這是用 torch 的架構,用的資料集是 mnist。也就是說試著用 GAN 去訓練產生手寫數字的模型。另外也有找到一篇是用 tensorflow 寫的,用的資料集是 fashion-mnist,這二篇試著跑過一次,都成功完成,看今天如果來不及介紹就明天再介紹 tensorflow 寫的囉。

先來介紹第一篇,原始的程式碼請看作者文章的連結,作者文章也有詳細的介紹,我們就記錄一下比較重點的地方,並且我會簡化一些參數,所以不要直接拿這篇的程式碼去跑唷。

定義資料集,在這裡使用的數據集是 MNIST 手寫數字圖像數據集。創建Dataset類,用 datasets.MNIST 下載 MNIST 手寫數字圖像數據集並初始化數據加載器,加載器會將數據送到模型中。

def mnist_data():
    compose = transforms.Compose(
        [transforms.ToTensor(),
         # transforms.Normalize((.5, .5, .5), (.5, .5, .5))
         transforms.Normalize([0.5], [0.5])
        ])
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

data = mnist_data()

在 compose 這段程式碼 colab 執行時會有錯誤,要改一下程式碼。

  • 原來的 : transforms.Normalize((.5, .5, .5),(.5, .5, .5))
  • 修正後的: transforms.Normalize([0.5], [0.5])

再來定義鑑別網路 DiscriminatorNet,他定義了三層的隱藏層。

        self.hidden0 = nn.Sequential( 
            nn.Linear(n_features,1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024,512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512,256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

在生產網路 GeneratorNet,一樣也定義三層隱藏層。

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features,256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(            
            nn.Linear(256,512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512,1024),
            nn.LeakyReLU(0.2)
        )

接下來,繼續定義

  • 創建噪聲採樣器(noise sampler)。
  • 優化器(Optimizers),這裡用的是Adam優化器。
    D優化器 = Adam(鑑別器的參數)
    G優化器 = Adam(生成器的參數)
  • 初始化損失函數(Loss fucntion)。

最後就是全部串起來開始跑囉。
訓練鑑別器。

 train_discriminator(D優化器, 真樣本, 假樣本)

訓練生成器。

 train_generator(G優化器, 假樣本)

最後就是一些訓練過程的截圖,可以看到從一開始的雜訊圖樣慢慢呈現出手寫數字的樣子。

一開始的雜訊圖樣。

epoch:4。雜訊稍微減少一點。

epoch:20。看起來有點像數字,但很殘破。

epoch:143。已經很像手寫數字了。

以上就是 Rajaswa Patil 所分享程式的介紹,詳細的程式碼請參考作者的原稿唷。

好,第25天,結束。
/images/emoticon/emoticon06.gif

參考
Training GANs using Google Colaboratory
ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION
機器學習ML NOTE SGD,Momentum,AdaGrad,Adam Optimizer


上一篇
GAN Lab(2) - GAN Lab 介面
下一篇
GAN Lab(4) - CartoonGAN
系列文
「Google Machine Learning」學習筆記31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

1 則留言

0
阿展展展
iT邦好手 1 級 ‧ 2020-01-20 06:34:32

熟悉的123456789最對味 /images/emoticon/emoticon46.gif

哈,練習的好朋友,親切的資料集 ^__^

我要留言

立即登入留言