今天介紹Keras提供的好用工具,幫助我們在訓練時,針對模型的狀態來做出反應,或單純建立Log。他的用法相當簡單,先建立出物件,在訓練的時候傳遞給它就好了,我們看看範例。
from tensorflow.keras.callbacks import EarlyStopping
es = EarlyStopping(monitor='val_loss', patience=3, mode='auto', restore_best_weights=True)
model.fit(
x = x_train,
y = y_train,
batch_size = 32,
epochs = 20,
validation_split = 0.1,
callbacks = [mcp, log, ton],
verbose = 2
)
前面我們說到防止過擬合的一個方法是提早停止訓練,在這個例子裡,如果模型經過連續3輪,val_loss都沒有改善,就停止訓練,並且將權重回退至該輪的狀態。你可以在mode填入min,但是如果你監控的是val_accuracy,就必須使用max,建議保持auto,讓它自由挑選,你也可以在callbacks加入兩個物件,一個監控val_loss,一個監控val_accuracy。
分享幾個好用的callbacks。
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN, LearningRateScheduler
mcp = ModelCheckpoint(filepath='mnist-{epoch:02d}.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
log = CSVLogger(filename='mnist.csv', separator=',', append=False)
ton = TerminateOnNaN()
lrs = LearningRateScheduler(lambda epoch, lr: lr if epoch < 10 else lr / 2)
給使用Colab的同學的小指引,你的工作目錄預設在/content,如果你有掛載你的雲端硬碟,它會被掛載到/content/drive/My Drive,因此想要存檔到雲端硬碟的話,可以這麼寫。
log = CSVLogger(filename='/content/drive/My Drive/mnist.csv', separator=',', append=False)