2024 iThome 鐵人賽

DAY 18
Introduction

When training a model, you may find that from a certain epoch onward the evaluation metric on the validation set (or whichever metric you are tracking) stops improving. In that case you can use EarlyStopping to stop training early.

EarlyStopping

EarlyStopping is a class in the keras.callbacks module. It stops training once the monitored metric has failed to improve for a specified number of epochs, which helps prevent overfitting.

Usage:

keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=0,
    verbose=0,
    mode="auto",
    baseline=None,
    restore_best_weights=False,
    start_from_epoch=0,
)

Parameters

  • monitor: the evaluation metric to monitor, usually val_loss
  • min_delta: the minimum change that counts as an improvement; a change smaller than this value is treated as no improvement
  • patience: the number of epochs with no improvement after which training is stopped
  • verbose: verbosity mode; set it to 1 to print early-stopping messages
  • mode: the direction in which the monitored metric should improve; the default auto infers it from the metric, or you can set min (e.g. for a loss, where smaller is better) or max (e.g. for an accuracy, where larger is better)
  • baseline: a baseline value for the monitored metric; if the metric does not reach this value, training is stopped
  • restore_best_weights: if set to True, the model weights are rolled back to the best ones seen during training when training ends (see the sketch after this list)
  • start_from_epoch: the epoch from which EarlyStopping starts monitoring
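
As an illustration (a hedged sketch, not code from this series), a callback that monitors val_loss, only counts improvements larger than 0.001, waits 10 epochs, and rolls back to the best weights could be configured like this, assuming a Keras version recent enough to support start_from_epoch:

from tensorflow.keras.callbacks import EarlyStopping

# Hypothetical configuration for illustration only:
# stop when val_loss has not improved by at least 0.001 for 10 consecutive
# epochs, ignore the first 5 epochs, and restore the best weights at the end.
early_stop = EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=10,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=5,
    verbose=1,
)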

Example

Let's add EarlyStopping to the code we have been building throughout this series.

First, import EarlyStopping:

from tensorflow.keras.callbacks import EarlyStopping

Then create the EarlyStopping callback:

early = EarlyStopping(monitor="val_accuracy", 
                      min_delta=0, 
                      patience=20, 
                      verbose=1, 
                      mode="max")

Because the monitored metric here is val_accuracy, mode is set to max. min_delta is set to 0, meaning any improvement, however small, counts as progress, and patience is set to 20, so training stops after 20 epochs without improvement.
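
Conceptually, the patience counting works roughly as follows (a simplified sketch with hypothetical metric values, not the actual Keras implementation):

# Simplified illustration of the patience logic for mode="max"
# (not the actual Keras source; patience shortened to 2 so the example triggers).
min_delta, patience = 0, 2
val_accuracy_per_epoch = [0.66, 0.88, 0.91, 0.84, 0.87]  # hypothetical values

best = float("-inf")
wait = 0
for epoch, value in enumerate(val_accuracy_per_epoch, start=1):
    if value > best + min_delta:   # only a change larger than min_delta counts
        best = value
        wait = 0
    else:
        wait += 1
        if wait >= patience:       # too many epochs without improvement -> stop
            print(f"Epoch {epoch}: early stopping")
            break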

Remember to pass the callback to the callbacks argument of fit_generator() or fit(), using the variable name you defined:

hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=100,
                            callbacks=[checkpoint, early])
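
Note that fit_generator() is deprecated in recent TensorFlow 2.x releases; model.fit() accepts generators directly, so the same call could be written as follows (a sketch using the variables defined above):

# In TF 2.x, model.fit() accepts data generators directly,
# so the deprecated fit_generator() call can be replaced with:
hist = model.fit(traindata,
                 steps_per_epoch=spe,
                 validation_data=valdata,
                 validation_steps=vs,
                 epochs=100,
                 callbacks=[checkpoint, early])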

Full code:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Load the dataset
trpath = "./Kaggle/data/"
trdata = ImageDataGenerator(validation_split=0.2)
traindata = trdata.flow_from_directory(directory=trpath, 
                                       target_size=(256,256),
                                       shuffle=True,
                                       subset='training')
valdata = trdata.flow_from_directory(directory=trpath, 
                                     target_size=(256,256), 
                                     shuffle=True,
                                     subset='validation')

# Define steps_per_epoch and validation_steps
spe = traindata.samples // traindata.batch_size # steps_per_epoch
vs = valdata.samples // valdata.batch_size # validation_steps

# Define the data augmentation layers
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

# Build the model
inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Freeze the VGG16 layers
for layer in base_model.layers:
    layer.trainable = False

# Compile the model
opt = Adam(learning_rate=0.0001)
model.compile(optimizer=opt, 
              loss="categorical_crossentropy", 
              metrics=["accuracy"]
             )
checkpoint = ModelCheckpoint("model.h5", 
                              monitor="val_accuracy", 
                              verbose=0, 
                              save_best_only=True, 
                              save_weights_only=False, 
                              mode='max', 
                              period=1)
early = EarlyStopping(monitor="val_accuracy", 
                      min_delta=0, 
                      patience=20, 
                      verbose=1, 
                      mode="max")
hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=100,
                            callbacks=[checkpoint, early])

Output:

Epoch 1/100
7/7 [==============================] - 1s 67ms/step - loss: 18.4279 - accuracy: 0.4028 - val_loss: 4.9985 - val_accuracy: 0.6562
Epoch 2/100
7/7 [==============================] - 0s 49ms/step - loss: 6.7101 - accuracy: 0.7083 - val_loss: 1.7502 - val_accuracy: 0.8750
Epoch 3/100
7/7 [==============================] - 0s 48ms/step - loss: 4.2408 - accuracy: 0.8194 - val_loss: 2.8444 - val_accuracy: 0.9062
Epoch 4/100
7/7 [==============================] - 0s 63ms/step - loss: 3.4256 - accuracy: 0.8661 - val_loss: 5.2930 - val_accuracy: 0.8438
Epoch 5/100
7/7 [==============================] - 0s 49ms/step - loss: 3.0273 - accuracy: 0.8795 - val_loss: 0.0107 - val_accuracy: 1.0000
Epoch 6/100
7/7 [==============================] - 0s 71ms/step - loss: 1.4176 - accuracy: 0.9259 - val_loss: 2.2335 - val_accuracy: 0.8750
Epoch 7/100
7/7 [==============================] - 1s 81ms/step - loss: 1.8321 - accuracy: 0.9167 - val_loss: 0.4204 - val_accuracy: 0.9688
Epoch 8/100
7/7 [==============================] - 1s 81ms/step - loss: 2.1764 - accuracy: 0.9018 - val_loss: 1.4816 - val_accuracy: 0.9375
Epoch 9/100
7/7 [==============================] - 1s 86ms/step - loss: 1.0331 - accuracy: 0.9375 - val_loss: 1.7004 - val_accuracy: 0.9375
Epoch 10/100
7/7 [==============================] - 1s 78ms/step - loss: 1.3516 - accuracy: 0.9464 - val_loss: 3.2331 - val_accuracy: 0.9062
Epoch 11/100
7/7 [==============================] - 1s 78ms/step - loss: 0.9099 - accuracy: 0.9491 - val_loss: 2.9501 - val_accuracy: 0.8438
Epoch 12/100
7/7 [==============================] - 1s 76ms/step - loss: 1.4062 - accuracy: 0.9398 - val_loss: 4.0914 - val_accuracy: 0.9375
Epoch 13/100
7/7 [==============================] - 1s 79ms/step - loss: 1.6024 - accuracy: 0.9306 - val_loss: 2.5077 - val_accuracy: 0.9688
Epoch 14/100
7/7 [==============================] - 1s 83ms/step - loss: 2.8048 - accuracy: 0.9352 - val_loss: 2.8890 - val_accuracy: 0.9062
Epoch 15/100
7/7 [==============================] - 1s 77ms/step - loss: 0.7819 - accuracy: 0.9583 - val_loss: 8.3409 - val_accuracy: 0.7812
Epoch 16/100
7/7 [==============================] - 1s 85ms/step - loss: 1.9119 - accuracy: 0.9398 - val_loss: 1.3450 - val_accuracy: 0.9375
Epoch 17/100
7/7 [==============================] - 1s 81ms/step - loss: 0.6153 - accuracy: 0.9676 - val_loss: 3.8260 - val_accuracy: 0.8438
Epoch 18/100
7/7 [==============================] - 1s 78ms/step - loss: 0.9660 - accuracy: 0.9676 - val_loss: 7.5690 - val_accuracy: 0.8438
Epoch 19/100
7/7 [==============================] - 1s 81ms/step - loss: 0.7829 - accuracy: 0.9815 - val_loss: 0.5642 - val_accuracy: 0.9688
Epoch 20/100
7/7 [==============================] - 1s 80ms/step - loss: 0.3769 - accuracy: 0.9861 - val_loss: 9.0406 - val_accuracy: 0.8750
Epoch 21/100
7/7 [==============================] - 1s 83ms/step - loss: 0.6738 - accuracy: 0.9815 - val_loss: 6.4338 - val_accuracy: 0.9062
Epoch 22/100
7/7 [==============================] - 1s 83ms/step - loss: 0.8787 - accuracy: 0.9722 - val_loss: 0.6963 - val_accuracy: 0.9375
Epoch 23/100
7/7 [==============================] - 1s 80ms/step - loss: 0.8309 - accuracy: 0.9769 - val_loss: 6.1115 - val_accuracy: 0.9062
Epoch 24/100
7/7 [==============================] - 1s 85ms/step - loss: 0.2586 - accuracy: 0.9866 - val_loss: 4.7981 - val_accuracy: 0.8438
Epoch 25/100
7/7 [==============================] - 1s 81ms/step - loss: 0.6755 - accuracy: 0.9537 - val_loss: 2.2561 - val_accuracy: 0.9062
Epoch 00025: early stopping

You can see that training stopped at epoch 25 instead of running through the full 100 epochs. This keeps the model from training past the point where it stops improving and starts to overfit, and it also saves training time.
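
If you also want to check which epoch gave the best validation accuracy, the History object returned by fit() keeps the per-epoch metrics. A small sketch (not part of the original code):

import numpy as np

# hist.history stores the per-epoch metrics recorded during training;
# find the epoch with the highest validation accuracy (the list is 0-indexed).
val_acc = hist.history["val_accuracy"]
best_epoch = int(np.argmax(val_acc)) + 1
print(f"Best val_accuracy {max(val_acc):.4f} at epoch {best_epoch}")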

Don't you think EarlyStopping is pretty handy? (I said something similar yesterday.) But it really is interesting! Tomorrow I will introduce the last Callbacks class in this series.

Previous post
[Day 17] Callback Modules (1): ModelCheckpoint
Next post
[Day 19] Callback Modules (3): CSVLogger
Series
輕鬆上手AI專案-影像分類到部署模型 (30 articles)