iT邦幫忙

第 12 屆 iT 邦幫忙鐵人賽

DAY 7
0
AI & Data

輕鬆掌握 Keras 及相關應用系列 第 7

Day 07:Keras Callback 的使用

前言

Callback 可以在模型訓練過程中觸發事件,記錄訓練過程產生的資訊、在查核點(Checkpoint)對模型存檔、迫使訓練提早結束...等,除了可以使用內建(built-in)的Callback,也可以自制(customize)Callback。

Callback搭配許多 Keras 內建的函數,可以完全解構模型訓練的過程。
以下我們就來使用一些範例,來說明Callback功能。

內建的Callback

常用的 Callback 包括:

  1. CSVLogger:可將訓練過程記錄至 CSV 檔案。
  2. TensorBoard:這是跟 Tensorflow 結合的有利支援,將訓練過程記錄存成 TensorBoard 檔案格式,直接使用 TensorBoard 工具觀看統計圖。
  3. ModelCheckpoint:由於訓練過程耗時,有可能訓練一半就當掉,因此,我們可以利用這個 Callback,在每一個檢查點(Checkpoint)存檔,下次執行時,就可以從中斷點繼續訓練。
  4. EarlyStopping:可設定訓練提早結束的條件。
  5. LearningRateScheduler:可動態調整學習率(Learning Rate)。

其他還有:

  1. ReduceLROnPlateau:當訓練已無改善時,可以降低學習率,追求更細微的改善,找到更精準的最佳解。
  2. LambdaCallback:直接使用匿名函數自制Callback。
  3. TerminateOnNaN:當損失函數為NaN(Null value),訓練提早結束。
  4. ProgbarLogger:記錄訓練進度。

測試

我們直接拿 MNIST 辨識作各種 Callback 測試:

  1. EarlyStopping:定義 validation accuracy 三個執行週期沒改善就停止訓練
# validation loss 三個執行週期沒改善就停止訓練
my_callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=3, monitor = 'val_accuracy'),
]

# 訓練 20 次
history = model.fit(x_train_norm, y_train, epochs=20, validation_split=0.2, callbacks=my_callbacks)

訓練 20 次,但實際只訓練 13次就停止了,因為連續三個執行週期validation accuracy沒改善。也可以改為 val_loss,只訓練 5 次就停止了。畫面如下:
https://ithelp.ithome.com.tw/upload/images/20200907/20001976OxKYWmRTaU.png
第2次準確度為0.9803,之後第3~5次都沒超過第2次,訓練就停止了。看圖也可以。
https://ithelp.ithome.com.tw/upload/images/20200907/20001976ldJnNuV42v.png

  1. TensorBoard
    TensorBoard 是 Tensorflow 提供的視覺化工具,功能非常強大,除了可以顯示訓練的過程,也可以顯示圖片及語音。在訓練的過程中就可以啟動TensorBoard,即時觀看訓練資訊。
# 定義 tensorboard callback
tensorboard_callback = [tf.keras.callbacks.TensorBoard(log_dir='.\\logs')]

# 訓練 10 次
history = model.fit(x_train_norm, y_train, epochs=10, validation_split=0.2, callbacks=tensorboard_callback)

開啟 cmd/終端機,執行 tensorboard --logdir=.\logs,啟動網頁伺服器,再使用瀏覽器輸入以下網址,即可觀看訓練資訊:
http://localhost:6006/

相關資訊如下:

  • 【Scalars】頁籤:顯示準確度與損失函數線圖
    https://ithelp.ithome.com.tw/upload/images/20200907/20001976EfEkR2rcH4.png

  • 【Graphs】頁籤:顯示運算圖(Graphs)
    https://ithelp.ithome.com.tw/upload/images/20200907/2000197660yOao0l6k.png

  1. ModelCheckpoint:在每一個檢查點(Checkpoint)存檔。
# 定義 ModelCheckpoint callback
checkpoint_filepath = '.\\tmp\\checkpoint'
model_checkpoint_callback = [tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only=True)]

# 訓練 10 次
model.fit(x_train_norm, y_train, epochs=10, validation_split=0.2, callbacks=model_checkpoint_callback)

下次要從最近的檢查點開始繼續訓練,如下:

# 載入最近的檢查點的權重
model.load_weights(checkpoint_filepath)
# 訓練 5 次
model.fit(x_train_norm, y_train, epochs=5, validation_split=0.2, callbacks=model_checkpoint_callback)

我們可以看到準確率(accuracy)會接續上次繼續提升,而不是回到第1次訓練時的準確率。

結論

以上我們測了幾種常用的 callback,注意,model.fit 的參數callbacks值是一個list,可以一次加入多個callback,至於如何將更多資訊放入,就靜待下回分曉了。

本篇範例包括07_01_Callback.ipynb,可自【這裡】下載。


上一篇
Day 06:Keras 模型結構
下一篇
Day 08:TensorBoard 的初體驗
系列文
輕鬆掌握 Keras 及相關應用30

1 則留言

0
frankyeh
iT邦新手 5 級 ‧ 2020-12-28 09:13:15

請問如果accuracy到達某個值之後停止訓練,callback是否有支援?

可以的。
EarlyStopping(monitor='val_accuracy', mode='max', min_delta=1)

我要留言

立即登入留言