When training a model, you may find that from some epoch onward the metric on the validation set (or whichever metric you track) simply stops improving. In that situation you can use EarlyStopping to halt training automatically.
EarlyStopping is a class in the keras.callbacks module. It lets you stop training once the monitored metric has failed to improve for a given number of epochs, which also helps prevent overfitting.
Usage:
keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=0,
    verbose=0,
    mode="auto",
    baseline=None,
    restore_best_weights=False,
    start_from_epoch=0,
)
monitor: the metric to monitor, usually set to val_loss.
min_delta: the minimum change that counts as an improvement; a change smaller than this value is treated as no improvement.
patience: how many epochs without improvement to wait before stopping training.
verbose: verbosity mode; setting it to 1 prints a message when Early Stopping triggers.
mode: the direction in which the monitored metric should improve. Defaults to auto (inferred automatically); choose min for a loss (where smaller is better) or max for an accuracy (where larger is better).
baseline: a baseline value for the monitored metric; if the metric does not reach this value, training stops.
restore_best_weights: if set to True, the model's weights are restored to the best ones seen during training when training ends.
start_from_epoch: the epoch from which EarlyStopping starts monitoring.
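To see a few of these parameters together, here is a minimal sketch (the values are purely illustrative, not from this series' example) that monitors val_loss, skips a few warm-up epochs, and rolls back to the best weights when training stops:

from tensorflow.keras.callbacks import EarlyStopping

# Illustrative settings: stop if val_loss has not improved by at least
# 0.001 for 10 consecutive epochs; ignore the first 5 epochs entirely;
# when training stops, restore the weights from the best epoch.
# Note: start_from_epoch is only available in fairly recent Keras versions.
early_demo = EarlyStopping(monitor="val_loss",
                           min_delta=0.001,
                           patience=10,
                           mode="min",
                           start_from_epoch=5,
                           restore_best_weights=True,
                           verbose=1)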
Now let's add EarlyStopping to the code we've been building in this series. First, import EarlyStopping:
from tensorflow.keras.callbacks import EarlyStopping
Then create the EarlyStopping callback:
early = EarlyStopping(monitor="val_accuracy",
                      min_delta=0,
                      patience=20,
                      verbose=1,
                      mode="max")
Since the monitored metric here is val_accuracy, mode is set to max accordingly. min_delta is set to 0, meaning any improvement, however tiny, counts as progress (with min_delta=0.01, for instance, a rise from 0.90 to 0.905 would not count). patience is set to 20, so training stops after 20 consecutive epochs without improvement.
Remember to add the callbacks argument to fit_generator() or fit(), passing in the variable you just created:
hist = model.fit_generator(steps_per_epoch=spe,
                           generator=traindata,
                           validation_data=valdata,
                           validation_steps=vs,
                           epochs=100,
                           callbacks=[checkpoint, early])
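One thing worth knowing: fit_generator() is deprecated in TensorFlow 2.x, where model.fit() accepts generators directly. On a recent version, the equivalent call (same variables as above) would look roughly like this:

hist = model.fit(traindata,
                 steps_per_epoch=spe,
                 validation_data=valdata,
                 validation_steps=vs,
                 epochs=100,
                 callbacks=[checkpoint, early])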
Complete code:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Load the dataset
trpath = "./Kaggle/data/"
trdata = ImageDataGenerator(validation_split=0.2)
traindata = trdata.flow_from_directory(directory=trpath,
                                       target_size=(256, 256),
                                       shuffle=True,
                                       subset='training')
valdata = trdata.flow_from_directory(directory=trpath,
                                     target_size=(256, 256),
                                     shuffle=True,
                                     subset='validation')

# Define steps_per_epoch and validation_steps
spe = traindata.samples // traindata.batch_size  # steps_per_epoch
vs = valdata.samples // valdata.batch_size       # validation_steps

# Define the data augmentation layers
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

# Build the model
inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Freeze the pre-trained VGG16 layers
for layer in base_model.layers:
    layer.trainable = False

# Compile the model
opt = Adam(learning_rate=0.0001)
model.compile(optimizer=opt,
              loss="categorical_crossentropy",
              metrics=["accuracy"])

checkpoint = ModelCheckpoint("model.h5",
                             monitor="val_accuracy",
                             verbose=0,
                             save_best_only=True,
                             save_weights_only=False,
                             mode='max',
                             period=1)
early = EarlyStopping(monitor="val_accuracy",
                      min_delta=0,
                      patience=20,
                      verbose=1,
                      mode="max")
hist = model.fit_generator(steps_per_epoch=spe,
                           generator=traindata,
                           validation_data=valdata,
                           validation_steps=vs,
                           epochs=100,
                           callbacks=[checkpoint, early])
Execution result:
Epoch 1/100
7/7 [==============================] - 1s 67ms/step - loss: 18.4279 - accuracy: 0.4028 - val_loss: 4.9985 - val_accuracy: 0.6562
Epoch 2/100
7/7 [==============================] - 0s 49ms/step - loss: 6.7101 - accuracy: 0.7083 - val_loss: 1.7502 - val_accuracy: 0.8750
Epoch 3/100
7/7 [==============================] - 0s 48ms/step - loss: 4.2408 - accuracy: 0.8194 - val_loss: 2.8444 - val_accuracy: 0.9062
Epoch 4/100
7/7 [==============================] - 0s 63ms/step - loss: 3.4256 - accuracy: 0.8661 - val_loss: 5.2930 - val_accuracy: 0.8438
Epoch 5/100
7/7 [==============================] - 0s 49ms/step - loss: 3.0273 - accuracy: 0.8795 - val_loss: 0.0107 - val_accuracy: 1.0000
Epoch 6/100
7/7 [==============================] - 0s 71ms/step - loss: 1.4176 - accuracy: 0.9259 - val_loss: 2.2335 - val_accuracy: 0.8750
Epoch 7/100
7/7 [==============================] - 1s 81ms/step - loss: 1.8321 - accuracy: 0.9167 - val_loss: 0.4204 - val_accuracy: 0.9688
Epoch 8/100
7/7 [==============================] - 1s 81ms/step - loss: 2.1764 - accuracy: 0.9018 - val_loss: 1.4816 - val_accuracy: 0.9375
Epoch 9/100
7/7 [==============================] - 1s 86ms/step - loss: 1.0331 - accuracy: 0.9375 - val_loss: 1.7004 - val_accuracy: 0.9375
Epoch 10/100
7/7 [==============================] - 1s 78ms/step - loss: 1.3516 - accuracy: 0.9464 - val_loss: 3.2331 - val_accuracy: 0.9062
Epoch 11/100
7/7 [==============================] - 1s 78ms/step - loss: 0.9099 - accuracy: 0.9491 - val_loss: 2.9501 - val_accuracy: 0.8438
Epoch 12/100
7/7 [==============================] - 1s 76ms/step - loss: 1.4062 - accuracy: 0.9398 - val_loss: 4.0914 - val_accuracy: 0.9375
Epoch 13/100
7/7 [==============================] - 1s 79ms/step - loss: 1.6024 - accuracy: 0.9306 - val_loss: 2.5077 - val_accuracy: 0.9688
Epoch 14/100
7/7 [==============================] - 1s 83ms/step - loss: 2.8048 - accuracy: 0.9352 - val_loss: 2.8890 - val_accuracy: 0.9062
Epoch 15/100
7/7 [==============================] - 1s 77ms/step - loss: 0.7819 - accuracy: 0.9583 - val_loss: 8.3409 - val_accuracy: 0.7812
Epoch 16/100
7/7 [==============================] - 1s 85ms/step - loss: 1.9119 - accuracy: 0.9398 - val_loss: 1.3450 - val_accuracy: 0.9375
Epoch 17/100
7/7 [==============================] - 1s 81ms/step - loss: 0.6153 - accuracy: 0.9676 - val_loss: 3.8260 - val_accuracy: 0.8438
Epoch 18/100
7/7 [==============================] - 1s 78ms/step - loss: 0.9660 - accuracy: 0.9676 - val_loss: 7.5690 - val_accuracy: 0.8438
Epoch 19/100
7/7 [==============================] - 1s 81ms/step - loss: 0.7829 - accuracy: 0.9815 - val_loss: 0.5642 - val_accuracy: 0.9688
Epoch 20/100
7/7 [==============================] - 1s 80ms/step - loss: 0.3769 - accuracy: 0.9861 - val_loss: 9.0406 - val_accuracy: 0.8750
Epoch 21/100
7/7 [==============================] - 1s 83ms/step - loss: 0.6738 - accuracy: 0.9815 - val_loss: 6.4338 - val_accuracy: 0.9062
Epoch 22/100
7/7 [==============================] - 1s 83ms/step - loss: 0.8787 - accuracy: 0.9722 - val_loss: 0.6963 - val_accuracy: 0.9375
Epoch 23/100
7/7 [==============================] - 1s 80ms/step - loss: 0.8309 - accuracy: 0.9769 - val_loss: 6.1115 - val_accuracy: 0.9062
Epoch 24/100
7/7 [==============================] - 1s 85ms/step - loss: 0.2586 - accuracy: 0.9866 - val_loss: 4.7981 - val_accuracy: 0.8438
Epoch 25/100
7/7 [==============================] - 1s 81ms/step - loss: 0.6755 - accuracy: 0.9537 - val_loss: 2.2561 - val_accuracy: 0.9062
Epoch 00025: early stopping
You can see that training stopped at epoch 25 instead of running the full 100 epochs, so the model doesn't keep training past the point where it stops improving and start overfitting (and it saves training cost, too).
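If you want to confirm this from code, here is a small sketch (assuming the early and hist variables defined above) that checks when training stopped and which epoch had the best validation accuracy:

# stopped_epoch is set by EarlyStopping when it halts training
# (zero-indexed in most Keras versions).
print("Stopped at epoch:", early.stopped_epoch)

# hist.history records every metric per epoch; find the best val_accuracy.
val_acc = hist.history["val_accuracy"]
best_epoch = val_acc.index(max(val_acc)) + 1
print(f"Best val_accuracy: {max(val_acc):.4f} at epoch {best_epoch}")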
Isn't EarlyStopping amazing? (I feel like I said something similar yesterday.) But it really is fun! Tomorrow we'll look at the last Callbacks class in this series~