iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 18
0
AI & Data

AI從入門到放棄系列 第 18

Day 18 ~ AI從入門到放棄 - 訓練回調

  • 分享至 

  • xImage
  •  

今天介紹Keras提供的好用工具,幫助我們在訓練時,針對模型的狀態來做出反應,或單純建立Log。他的用法相當簡單,先建立出物件,在訓練的時候傳遞給它就好了,我們看看範例。

from tensorflow.keras.callbacks import EarlyStopping

es = EarlyStopping(monitor='val_loss', patience=3, mode='auto', restore_best_weights=True)

model.fit(
  x = x_train,
  y = y_train,
  batch_size = 32,
  epochs = 20,
  validation_split = 0.1,
  callbacks = [mcp, log, ton],
  verbose = 2
)

前面我們說到防止過擬合的一個方法是提早停止訓練,在這個例子裡,如果模型經過連續3輪,val_loss都沒有改善,就停止訓練,並且將權重回退至該輪的狀態。你可以在mode填入min,但是如果你監控的是val_accuracy,就必須使用max,建議保持auto,讓它自由挑選,你也可以在callbacks加入兩個物件,一個監控val_loss,一個監控val_accuracy。

分享幾個好用的callbacks。

  • ModelCheckpoint:此例子中只要val_loss有變好,就將權重存檔,檔名可以使用format字串。
  • CSVLogger:以csv格式保存訓練的loss和accuracy等資訊到檔案。
  • TerminateOnNaN:loss變成NaN時停止,通常發生在學習率過大開始發散時。
  • LearningRateScheduler:允許你自訂函數,動態調整學習率,此例在10個epoch前,學習率不變,10之後的每一個epoch學習率變之前的一半。
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN, LearningRateScheduler

mcp = ModelCheckpoint(filepath='mnist-{epoch:02d}.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
log = CSVLogger(filename='mnist.csv', separator=',', append=False)
ton = TerminateOnNaN()
lrs = LearningRateScheduler(lambda epoch, lr: lr if epoch < 10 else lr / 2)

給使用Colab的同學的小指引,你的工作目錄預設在/content,如果你有掛載你的雲端硬碟,它會被掛載到/content/drive/My Drive,因此想要存檔到雲端硬碟的話,可以這麼寫。

log = CSVLogger(filename='/content/drive/My Drive/mnist.csv', separator=',', append=False)

上一篇
Day 17 ~ AI從入門到放棄 - 資料增強
下一篇
Day 19 ~ AI從入門到放棄 - 應用到目前為止所學到的技巧
系列文
AI從入門到放棄30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言