iT邦幫忙

2024 iThome 鐵人賽

DAY 19
0

前言

前兩天介紹 Callbacks 模組中的兩個類別,今天要來介紹 CSVLogger,也是可以用來監控模型在訓練過程中的評估指標。和前兩天不同,CSVLogger 可以將這些評估指標記錄下來,儲存成一份檔案,就可以不用將訓練過程複製到筆記本,或是不小心關掉視窗難以查找所需資訊啦~

CSVLogger

CSVLogger 為 keras.callbacks 模組中的一個類別,用來記錄訓練過程中,每一個週期的評估指標,並將其儲存成檔案,供日後有需要的時候查看,或是將這些數值繪製成圖表等。

使用方法:

keras.callbacks.CSVLogger(filename, separator=",", append=False)

參數說明:

  • filename:用來儲存日誌的檔案名稱(或完整路徑),副檔名除了是 .csv,也可以使用 .txt.log
  • separator:分隔 CSV 文件的字符,預設為逗點 ,
  • append:是否將新數據附加至現有文件後,設定 True 會附加於現有數據後,False 會覆蓋現有的數據

實作範例

CSVLogger 的功能加入本系列實作程式碼。

首先匯入 CSVLogger

from tensorflow.keras.callbacks import CSVLogger

加入 CSVLogger

csv_logger = CSVLogger("training.csv", append=True)

記得在 fit_generator()fit() 加入 callbacks,裡面輸入設定的變數名稱:

hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=100,
                           callbacks=[checkpoint, early, csv_logger])

完整程式碼:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

# 載入資料集
trpath = "./Kaggle/data/"
trdata = ImageDataGenerator(validation_split=0.2)
traindata = trdata.flow_from_directory(directory=trpath, 
                                       target_size=(256,256),
                                       shuffle=True,
                                       subset='training')
valdata = trdata.flow_from_directory(directory=trpath, 
                                     target_size=(256,256), 
                                     shuffle=True,
                                     subset='validation')

# 設定 steps_per_epoch 和 validation_steps
spe = traindata.samples // traindata.batch_size # steps_per_epoch
vs = valdata.samples // traindata.batch_size # validation_steps

# 定義資料增強層
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

# 建立模型
inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# 凍結層
for layer in base_model.layers:
    layer.trainable = False

# 編譯模型
opt = Adam(learning_rate=0.0001)
model.compile(optimizer=opt, 
              loss="categorical_crossentropy", 
              metrics=["accuracy"]
             )
checkpoint = ModelCheckpoint("model.h5", 
                              monitor="val_accuracy", 
                              verbose=0, 
                              save_best_only=True, 
                              save_weights_only=False, 
                              mode='max', 
                              period=1)
early = EarlyStopping(monitor="val_accuracy", 
                      min_delta=0, 
                      patience=20, 
                      verbose=1, 
                      mode="max")
csv_logger = CSVLogger("training.csv", append=True)
hist = model.fit_generator(steps_per_epoch=spe, 
                            generator=traindata, 
                            validation_data=valdata, 
                            validation_steps=vs, 
                            epochs=100,
                           callbacks=[checkpoint, early, csv_logger])

執行後會得到一份 CSV 檔案,記錄了訓練過程的評估指標:
https://ithelp.ithome.com.tw/upload/images/20240924/20166645UB7H3S4rN5.png

因為設定了 append=True,再訓練一次模型,檔案內會看到新的數據附加在剛才訓練的結果之後:
https://ithelp.ithome.com.tw/upload/images/20240924/20166645V3yLUGeO1Q.png

是不是很方便呢?Callbacks 模組還有其他好用的功能,都可以嘗試看看。

明天要來開始畫圖啦!/images/emoticon/emoticon42.gif

參考資料


上一篇
[Day 18] 回呼模組 (2):EarlyStopping
下一篇
[Day 20] 訓練結果可視化
系列文
輕鬆上手AI專案-影像分類到部署模型30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言