iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 17
1
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 17

【17】tensorflow 訓練技巧:運用 GraphKeys.UPDATE_OPS 來更新 Batch Normalization 權重篇

  • 分享至 

  • xImage
  •  

今天要介紹的東西,可能很多新手寫錯了都還不知道,包括我早期使用時,大家可以檢視一下自己的模型。
不知道大家還不記得 day10 所介紹的 batch normalization 嗎?當時我示範了在 training_node 為 True 和 False 下會有不同的行為,因此,今天要來詳細介紹,再這兩個不同行為上,到底要如何更新參數!

警告,閱讀這邊文章前,請先確實充分理解 batch normalization 原理,若觀念還很模糊者請先點擊上方 day10 連結複習。

一開始,先建立示範用的模型。

input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32, name='input_node')
training_node = tf.placeholder_with_default(True, (), name='training')

net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.batch_normalization(net, training=training_node, name='bn')

視覺化如圖:
https://ithelp.ithome.com.tw/upload/images/20190925/20107299ImK14yJbXF.png

再來是 batch normalization 的兩個重要的參數,moving_mean 和 moving_var,我們在這邊取得這兩個 tensor 以便觀察。

moving_mean = tf.get_default_graph().get_tensor_by_name(
   "bn/moving_mean/read:0")
moving_var = tf.get_default_graph().get_tensor_by_name(
   "bn/moving_variance/read:0")

關鍵的地方到了,既然 moving_mean 和 moving_var 要在training_node=True時計算,那該如何去更新它呢?其實在宣告 tf.layers.batch_normalization 時,tensorflow 會自動把它的 update operation 放進全域的變數區裡,要拿到這個 op,我們可以透過 tf.get_collection 來取得,示範如下。

update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(f'update_op: {update_op}')

把 update_op 印出來可以看到:

update_op: [<tf.Operation 'bn/AssignMovingAvg' type=AssignSub>, <tf.Operation 'bn/AssignMovingAvg_1' type=AssignSub>]

那麼,現在我們要在 train_op 執行之前,先執行上面那個 update_op,這就會用到幾天前講到的 control_dependencies 來排定。

with tf.control_dependencies(update_op):
   train_op = tf.identity(net, name='train_op')

如此,只要執行一次 train_op,tensorflow 就會先跑 batch normalization 的 update_op。

我們實際run一次結果:

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   sess.run(tf.local_variables_initializer())

   image = cv2.imread('../05/ithome.jpg')
   image = np.expand_dims(image, 0)

   for _ in range(100):
       sess.run(train_op, feed_dict={input_node: image})

   result, mm, mv = sess.run([net, moving_mean, moving_var],
     feed_dict={input_node: image, training_node: False})
   print(f'with_update_op:\n(mm , mv) : ({mm[0]:.2f} , {mv[0]:.2f})\n{result[0, 22:28, 22:28, 0]}')

會印出:
https://ithelp.ithome.com.tw/upload/images/20190925/20107299RhjqSknIyX.png

我們看到 moving_mean 和 moving_var 分別是 -6.73 和 82.14,這一定是有經過 update_op 的計算所產生出的結果,另外,也可以看到輸出的數值也很明顯經過 normalization。

那如果沒有先經過 update_op 的 normalization 會長怎樣呢?我們來實驗看看

input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32, name='input_node')
training_node = tf.placeholder_with_default(True, (), name='training')

net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.batch_normalization(net, training=training_node, name='bn')

moving_mean = tf.get_default_graph().get_tensor_by_name(
   "bn/moving_mean/read:0")
moving_var = tf.get_default_graph().get_tensor_by_name(
   "bn/moving_variance/read:0")

train_op = tf.identity(net, name='train_op')

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   sess.run(tf.local_variables_initializer())

   image = cv2.imread('../05/ithome.jpg')
   image = np.expand_dims(image, 0)

   for _ in range(10):
       sess.run(train_op, feed_dict={input_node: image})

   result, mm, mv = sess.run([net, moving_mean, moving_var], feed_dict={input_node: image, training_node: False})
   print(f'without_update_op:\n(mm , mv) : ({mm[0]:.2f} , {mv[0]:.2f})\n{result[0, 22:28, 22:28, 0]}')

印出:
https://ithelp.ithome.com.tw/upload/images/20190925/20107299gBURQC1Le2.png

moving_mean 和 moving_var 分別是 0 和 1,根本沒有更新過,而且輸出的數值也很明顯沒有 normalization...。

希望以上給大家的示範夠詳盡,很多新手如果不清楚這件事的話,竟會發現自己爽開很多 batch normalization layer,結果模型訓練時容易失敗或發生詭異現象,就該注意一下自己有沒有用對喔!

github原始碼


上一篇
【16】tensorflow 訓練技巧:用 piecewise_constant 達成可變動 learning rate 篇
下一篇
【18】tensorflow 訓練技巧:模型如何輕鬆又方便地做 regularization 篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言