iT邦幫忙

第 11 屆 iT 邦幫忙鐵人賽

DAY 9
0

Batch Normalization 與 Dropout 是兩個預防模型過擬合的方法,雖然在訓練時,只要簡單幾行就能將之裝上去,但我這次想介紹的是以更深入細節來認識 tensorflow 的 Batch Normalization 和 Dropout。

首先,先從建造模型開始,之前在介紹 convolution 層時,我用了四種方式,這次我選擇用 tf.layers 庫來示範。

tf.layers.batch_normalization
tf.layers.dropout

如果常常會忘記這些 API 哪些參數要放什麼的話,可以查官方文件,或者直接看程式原始碼(像我用 pycharm 按住 command 鍵再點一下就進去了。)

https://ithelp.ithome.com.tw/upload/images/20190917/20107299mKpIhem0Z1.png
https://ithelp.ithome.com.tw/upload/images/20190917/201072991sKuL3Dczw.png

我們定義一個簡單模型,第一層用convolution,後面接batch_normalization,再一次convolution,後面接dropout,最後用identity單獨把結果拉出來方便在tensorboard觀察。

input_node = tf.placeholder(shape=[None, 100, 100, 3], 
                                dtype=tf.float32, 
                                name='input_node')
training_node = tf.placeholder(shape=(), 
                               dtype=tf.bool, 
                               name='training')

net = tf.layers.conv2d(input_node, 32, (3, 3), 
                       strides=(2, 2), 
                       padding='same', 
                       name='conv_1')
net = tf.layers.batch_normalization(net, 
                                    training=training_node, 
                                    name='bn')

net = tf.layers.conv2d(net, 32, (3, 3), 
                       strides=(1, 1), 
                       padding='same', 
                       name='conv_2')
net = tf.layers.dropout(net, 
                        rate=0.6, 
                        training=training_node, 
                        name='dropout')

tf.identity(net, name='final')

希望以上的程式碼對大家來說是很輕而易舉地實現。
接著產生session,初始化參數,封住 graph 並存成 pb 檔,並產生tfevent。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['final'])

    tf.summary.FileWriter(OUTPUT_PATH, graph=frozen_graph)
    tf.io.write_graph(frozen_graph, "../pb/", "bn_dropout_model.pb", as_text=False)

我們就可以在 tensorboard 上看到如下圖網路。
https://ithelp.ithome.com.tw/upload/images/20190917/20107299tcurf35x5Y.png

後面幾天會在針對這兩種 layer 做更詳細的介紹。

github原始碼


上一篇
【08】tensorflow 細看存檔:load pb篇
下一篇
【10】從 tensorboard 來觀察:你容易忽略的 batch_normalization 原理篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31

尚未有邦友留言

立即登入留言