iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 16
0
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 16

【16】tensorflow 訓練技巧:用 piecewise_constant 達成可變動 learning rate 篇

  • 分享至 

  • xImage
  •  

今天要介紹的是 tensorflow 中的 piecewise_constant 功能,但在介紹之前,先來介紹 global step ,global step 的概念很簡單,當我們訓練模型時,我們會對 cost function 做一次又一次的 gradient descent 而產生 train_op,而為了記錄目前是第幾次執行 train_op,我們會需要一個變數來幫忙,這個變數我們就稱作 global step。

在 tensorflow 中若要產生 global step,比較早期的做法是宣告 variable ,但其實可以直接呼叫 tf.train.get_or_create_global_step(),而這次的範例中我們沒有真正要做 gradient descent ,所以我把 update_op 當作 train_op 且定義為 global_step 加 1。

global_step = tf.train.get_or_create_global_step()
update_op = tf.assign_add(global_step, 1)

接下來說說 piecewise_constant() 吧,在做訓練時,我們會設定一個超參數叫 learning rate ,它所代表的意思為每次gradient descent 時,偏微分權重乘上的一個常數,可用於決定權重的更新量大小,如果設定的太小,模型收斂很慢,如果設定太大,模型損失值可能會陷入震盪,這個時後 piecewise_constant 會是個不錯的解法,例如:我們可以設定像是第 2000 step 以前,learning rate 為 0.1 ,這時更新速度比較快,第 2000~6000 步為 0.05,來避免震盪又可以繼續壓低損失函數,第 6000 步後降為 0.01,最後等著收斂。

有此可知,piecewise_constant 需要三個必要變數,分別為目前第幾步第幾步時變更速度各個步數區間的速度值,示範如下:

lr_steps = [5, 10, 15]
lr_values = [0.1, 0.05, 0.01, 0.001]

所代表的意思為:
0~5 步,learning rate = 0.1,
5~10 步,learning rate = 0.05,
10~15 步,learning rate = 0.01,
15 步後,learning rate = 0.001。

還記得昨天介紹的 control_dependencies 嗎?這次我們定義每次拿到的 lr_value 前,先執行 update_op 來更新 global step。

with tf.control_dependencies([update_op]):
   lr_value = tf.train.piecewise_constant(global_step, boundaries=lr_steps, values=lr_values, name='lr_rate')

最後我們就來實驗輸出結果啦:

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   sess.run(tf.local_variables_initializer())

   for _ in range(EPOCHS):
       gs, lr = sess.run([global_step, lr_value])

       print(f'step {gs}, lr = {lr:.3f}')

產生:
https://ithelp.ithome.com.tw/upload/images/20190924/20107299zevSvgaNAf.png

節點圖:
https://ithelp.ithome.com.tw/upload/images/20190924/20107299FfbC6oT8MB.png

這樣就可以很有效率運用會變化的 learning rate 來訓練模型。

github原始碼


上一篇
【15】tensorflow 訓練技巧:control_dependencies 篇
下一篇
【17】tensorflow 訓練技巧:運用 GraphKeys.UPDATE_OPS 來更新 Batch Normalization 權重篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言