iT邦幫忙

2024 iThome 鐵人賽

DAY 17
0

前言

在訓練深度學習模型時,如果每次都要等待自己設定的訓練週期結束,才檢查模型在哪一個週期才是訓練最好的「最佳模型」,聽起來是一件沒有效率的事。Keras 有個好用的 Callbacks 模組,可以使用它來監控模型訓練的過程,減少手動去調整的次數。Callbacks 模組有許多功能,可以記錄訓練過程的評估指標、動態調整學習率、保存最佳模型或是設定條件停止訓練等等,例如 CSVLogger 可以將每一個訓練週期的評估指標記錄下來,LearningRateSchedulerReduceLROnPlateau 和動態調整學習率相關,ModelCheckpoint 可以依照條件保存最佳模型,EarlyStopping 可以指定多少週期沒有改善就停止訓練。

這系列我們會介紹其中 3 種:ModelCheckpointEarlyStoppingCSVLogger,今天先介紹 ModelCheckpoint

ModelCheckpoint

ModelCheckpoint 可以在訓練模型的過程中,設定檢查點以保存模型。

使用方法:

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
)

參數說明:

  • filepath:儲存模型的路徑,副檔名為 .h5 或是 .keras,前者較通用但檔案較大,後者較新出現,需注意版本是否相容
  • monitor:用來監控模型的指標,通常是設定 val_loss,會依據這個指標去儲存模型
  • verbose:儲存模型的輸出訊息,設定 0 為不輸出,設定 1 為每次儲存模型會顯示訊息
  • save_best_only:設定 True 為當 monitor 設定的評估指標有改善才會儲存模型,設定 False 會儲存每一個週期的模型(預設值為 False
  • save_weights_only:設定 True 為只儲存權重,不儲存模型結構,設定 False 為儲存整個模型(預設值為 False
  • mode:監控評估指標的模式,預設為 auto,自動判斷要監控的評估指標要看的是最大值還是最小值,例如 val_loss 會看最小值,val_accuracy 會看最大值,或是直接設定對應的參數,如最大值設定為 max,最小值設定為 min
  • save_freq:設定模型儲存頻率,epoch 表示每一個週期都儲存模型,設定數值表示多少個 batch 儲存模型
  • initial_value_threshold:設定初始閾值,依照監控評估指標超過或低於設定的閾值才儲存模型,預設為 None

實作範例

ModelCheckpoint 的功能加入本系列實作程式碼。

首先匯入 ModelCheckpoint

from tensorflow.keras.callbacks import ModelCheckpoint

加入 ModelCheckpoint

checkpoint = ModelCheckpoint("model.h5", 
                              monitor="val_accuracy", 
                              verbose=1, 
                              save_best_only=True, 
                              save_weights_only=False, 
                              mode='max', 
                              period=1)

記得在 fit_generator()fit() 加入 callbacks,裡面輸入設定的變數名稱:

hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=20,
                           callbacks=[checkpoint])

完整程式碼:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

# 載入資料集
trpath = "./Kaggle/data/"
trdata = ImageDataGenerator(validation_split=0.2)
traindata = trdata.flow_from_directory(directory=trpath, 
                                       target_size=(256,256),
                                       shuffle=True,
                                       subset='training')
valdata = trdata.flow_from_directory(directory=trpath, 
                                     target_size=(256,256), 
                                     shuffle=True,
                                     subset='validation')
                                     
# 設定 steps_per_epoch 和 validation_steps
spe = traindata.samples // traindata.batch_size # steps_per_epoch
vs = valdata.samples // traindata.batch_size # validation_steps

# 定義資料增強層
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

# 建立模型
inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# 凍結層
for layer in base_model.layers:
    layer.trainable = False
    
# 編譯模型
opt = Adam(learning_rate=0.0001)
model.compile(optimizer=opt, 
              loss="categorical_crossentropy", 
              metrics=["accuracy"]
             )
             
# 設定檢查點
checkpoint = ModelCheckpoint("model.h5", 
                              monitor="val_accuracy", 
                              verbose=1, 
                              save_best_only=True, 
                              save_weights_only=False, 
                              mode='max', 
                              period=1)
                              
#訓練模型
hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=20,
                           callbacks=[checkpoint])

因為設定 verbose=1,執行結果可以觀察到下列訊息:

Epoch 00004: val_accuracy did not improve from 0.84375

顯示訓練到哪一個週期,要監控的評估指標沒有改善。
還會儲存一個 .h5 (或其他副檔名)最佳模型檔到指定的路徑。

是不是也覺得 ModelCheckpoint 很方便?明天再來介紹另一個 Callbacks 類別哦~/images/emoticon/emoticon15.gif

參考資料


上一篇
[Day 16] 模型正則化方法 (2):Dropout
下一篇
[Day 18] 回呼模組 (2):EarlyStopping
系列文
輕鬆上手AI專案-影像分類到部署模型30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言