iT邦幫忙

2024 iThome 鐵人賽

DAY 28
0
AI/ ML & Data

輕鬆上手AI專案-影像分類到部署模型系列 第 28

[Day 28] 輸入訓練資料的另一種方法

  • 分享至 

  • xImage
  •  

前言

輸入訓練資料集的方法不只有一種,有時候取決於輸入資料的格式、套件的使用、資料處理的方法或模型的架構等等,例如同樣是影像作為輸入,就可以選擇不同的資料輸入方法。今天以輸入影像為例,介紹其中兩種輸入訓練資料的方法,並帶給大家一些深度學習的觀念。

image_dataset_from_directory()

第一種方式為使用 tensorflow.keras.preprocessing 中的 image_dataset_from_directory(),除了可以沿用舊版的 fit_generator(),也可以使用新版的 fit()

使用方法

使用 image_dataset_from_directory() 輸入訓練資料集:

from tensorflow.keras.preprocessing import image_dataset_from_directory

trpath = "your_dataset_path" # 訓練資料集路徑

# 定義訓練資料集
traindata = image_dataset_from_directory(
    trpath,
    validation_split=0.2,
    seed=42,
    label_mode="categorical",
    subset="training",
)
# 定義驗證資料集
valdata = image_dataset_from_directory(
    trpath,
    validation_split=0.2,
    seed=42,
    label_mode="categorical",
    subset="validation"
)

# 設定重複
traindata = traindata.repeat()
valdata = valdata.repeat()

參數說明

使用方法和 flow_from_directory() 類似,先指定資料集路徑來源(這裡訓練資料集和驗證資料集皆出自於 trpath,所以兩者設定相同路徑)。validation_split 用來切分驗證資料集,0.2 表示 20% 訓練資料集作為驗證資料集。seed 表示隨機種子,如果要進行多次訓練做比較,建議設定一個數字,每次隨機值才會相同。label_mode 設定為 categorical,表示標籤編碼為 One-hot Encoding,要注意損失函數的選擇有沒有對應。subset 設定為哪一個子集(訓練資料集或驗證資料集)。
(參數可以依照自己需求設定)

值得注意

flow_from_directory() 不同的是,需要設定 repeat(),因為 image_dataset_from_directory() 沒有循環的功能,需要重複才不會讓數據耗盡。

使用 fit() 訓練

fit_generator() 改寫成 fit(),在這裡只需要更改訓練方法名稱,就可以直接使用:

hist = model.fit(traindata,
                 steps_per_epoch=spe,
                 validation_data=valdata,
                 validation_steps=vs,
                 epochs=100,
                 callbacks=[checkpoint, csv_logger, early]
)

以 numpy.ndarray 型態作為輸入

第二種方式,也可以將影像檔案轉成 numpy.ndarray 型態,即多維陣列 N-dimensional Array。將影像檔一個一個讀取存成 Array,也要指定對應的標籤:

import os
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split

trpath = "your_dataset_path" # 訓練資料集路徑

# 定義訓練資料集路徑
filelist = []
black_path = "./data/black/"
grizzly_path = "./data/grizzly/"
panda_path = "./data/panda/"
polar_path = "./data/polar/"
teddy_path = "./data/teddy/"
# 定義取得完整路徑的函式
def get_full_paths(directory):
    return [os.path.join(directory, file) for file in os.listdir(directory)]
# 使用 extend() 添加訓練資料完整路徑
filelist.extend(get_full_paths(black_path))
filelist.extend(get_full_paths(grizzly_path))
filelist.extend(get_full_paths(panda_path))
filelist.extend(get_full_paths(polar_path))
filelist.extend(get_full_paths(teddy_path))

# 定義處理後的訓練資料集 x
x = []
for i in range(len(filelist)):
    img = load_img(filelist[i], target_size=(256, 256))
    img_array = img_to_array(img)
    x.append(img_array)
# 將變數類型重塑成符合模型輸入的變數類型
x = np.array(x).reshape(len(filelist), 256, 256, 3)
    
# 定義類別(標籤)
y = [0]*len(os.listdir(black_path))+[1]*len(os.listdir(grizzly_path))+[2]*len(os.listdir(panda_path))+[3]*len(os.listdir(polar_path))+[4]*len(os.listdir(teddy_path))
y = np.array(y).reshape(len(filelist),)

# 使用 train_test_split() 切分訓練集和驗證集
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)

程式說明與補充

  • 使用 os.path.join() 的原因是因為 os.listdir() 只有將目錄下所有檔案名稱列出來(如 black1.jpg),而這裡需要完整的路徑(如 ./data/black/black1.jpg)。
  • 之前有使用過 append(),這裡使用 extend(),差別在前者會將附加的東西當作一個元素添加,例如列表會直接附加進去(變成列表中還有列表),而後者會逐一添加,所以列表會每一個元素分開添加。
  • 訓練資料集 x 和標籤 y 都有使用 reshape(),以符合深度學習訓練資料的 Shape,如影像常使用 shape = (樣本數, 長, 寬, 通道)(例如 RGB 影像通道為 3),標籤 shape=(樣本數,)
  • 這裡使用 sklearn.model_selectiontrain_test_split 來切分訓練資料集與驗證資料集,test_size 為驗證資料集的比例,0.2 表示資料集 20% 作為驗證資料集,random_state 類似於 seed,設定數值可以確定每次切分的結果是相同的。

使用 fit() 訓練

這裡就不能使用 fit_generator() 了,因為不是使用生成器做為訓練資料集的輸入,要使用 fit() 來訓練,寫法也稍微不同:

hist = model.fit(x_train, 
                  y_train, 
                  epochs=100, 
                  batch_size=32, 
                  validation_data=(x_val, y_val), 
                  callbacks=[csv_logger, checkpoint, early])

設定訓練資料集 x_train 和訓練集標籤 y_train,以及 epochs 為週期數,batch_size 為批次大小,validation_data 為驗證資料集,放入驗證資料集 x_val 與對應標籤 y_val。

值得注意

因為標籤編碼為整數數值,若沒有轉換為 One-hot Encoding 的編碼形式,損失函數要記得更改為 sparse_categorical_crossentropy 才能訓練。

model.compile(optimizer=opt, 
              loss="sparse_categorical_crossentropy", 
              metrics=["accuracy"])

如果是處理數據檔案(例如 .npy 檔案),就可以使用這個方法來處理輸入資料集。

參考資料

  • François Chollet《Deep Learning with Python, Second Edition》(中文版:黃逸華、林采薇譯《Keras大神歸位:深度學習全面進化!用Python 實作CNN、RNN、GRU、LSTM、GAN、VAE、Transformer》,旗標出版)

上一篇
[Day 27] 好用的日誌工具:loguru
下一篇
[Day 29] 使用爬蟲技術蒐集圖片
系列文
輕鬆上手AI專案-影像分類到部署模型30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言