
2024 iThome Ironman Contest

DAY 11
AI / ML & Data

Easy AI Projects: From Image Classification to Model Deployment, Part 11

[Day 11] Model Training Methods: Transfer Learning


Introduction

When training a deep learning model, we usually take the dataset we have prepared and train from scratch. Sometimes, though, we run into trouble: the collected dataset is small, or its subject is niche. In those situations we can use transfer learning to train the model instead.

Transfer Learning

Put simply, transfer learning means taking a model trained in one domain and applying it to a dataset from another domain to solve a new problem. For example, if we want to build a truck classifier and many car classification models already exist, we can reuse one of them for classifying trucks. The idea mirrors how humans learn: when we learn something new, we lean on prior experience (things we have already learned), so a model that has learned to distinguish car types quite possibly has the capacity to distinguish truck types too.

Transfer learning brings other benefits as well. Because we load a pretrained model, we cut down on computation time. A pretrained model is one that has already been trained on another dataset, for example on ImageNet, a large image dataset: such a model can recognize 1,000 image classes, so it carries fairly general concepts and can handle general-purpose problems. Through fine-tuning we then adapt it to the needs of our own dataset.
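Before the Keras code, the core idea (keep a pretrained feature extractor frozen, train only a new head) can be sketched in plain NumPy. Everything here is made up for illustration: the "pretrained" weights are just a fixed random matrix, and the data is synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend these weights come from pretraining on another dataset; we never update them.
W_frozen = 0.5 * rng.normal(size=(4, 8))    # frozen feature extractor: 4 -> 8

# A small synthetic binary classification task standing in for the "new domain".
X = rng.normal(size=(64, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(float)

w_head = np.zeros(8)                        # trainable classification head: 8 -> 1
b_head = 0.0

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

feats = np.tanh(X @ W_frozen)               # forward pass through the frozen part
for _ in range(500):                        # gradient descent on the head only
    p = sigmoid(feats @ w_head + b_head)
    grad = p - y                            # dL/dlogits for binary cross-entropy
    w_head -= 0.5 * feats.T @ grad / len(X)
    b_head -= 0.5 * grad.mean()

acc = ((sigmoid(feats @ w_head + b_head) > 0.5) == y).mean()
print(f"accuracy after training only the head: {acc:.2f}")
```

Only `w_head` and `b_head` ever change; `W_frozen` plays the role of the pretrained, frozen layers.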

How to Use It

Let's revisit the model we built yesterday:

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers

# use data augmentation layer
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dense(4096, activation="relu")(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 256, 256, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 256, 256, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 256, 256, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 128, 128, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 128, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 128, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 64, 64, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 64, 64, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 32, 32, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 32, 32, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 16, 16, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 8, 8, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 32768)             0         
_________________________________________________________________
dense (Dense)                (None, 4096)              134221824 
_________________________________________________________________
dense_1 (Dense)              (None, 4096)              16781312  
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 20485     
=================================================================
Total params: 165,738,309
Trainable params: 165,738,309
Non-trainable params: 0
_________________________________________________________________

We can use freezing to tell certain layers not to train; a frozen layer's weights are never updated. For example, we can freeze VGG16's convolutional and pooling layers while keeping the ImageNet weights, so that part acts as the pretrained model, and only the fully connected layers that follow actually update their weights. Here is the code above adapted for transfer learning, using each layer's trainable attribute to mark which layers should or should not be trained:

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers

# use data augmentation layer
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dense(4096, activation="relu")(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# freeze every layer of the VGG16 base (weights stay at their ImageNet values)
for layer in base_model.layers:
    layer.trainable = False

This snippet freezes VGG16's convolutional and pooling layers (they will neither train nor update their weights). Running model.summary() gives:

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 256, 256, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 256, 256, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 256, 256, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 128, 128, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 128, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 128, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 64, 64, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 64, 64, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 32, 32, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 32, 32, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 16, 16, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 8, 8, 512)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 32768)             0         
_________________________________________________________________
dense_3 (Dense)              (None, 4096)              134221824 
_________________________________________________________________
dense_4 (Dense)              (None, 4096)              16781312  
_________________________________________________________________
dense_5 (Dense)              (None, 5)                 20485     
=================================================================
Total params: 165,738,309
Trainable params: 151,023,621
Non-trainable params: 14,714,688
_________________________________________________________________

Notice that the Trainable params and Non-trainable params counts have changed! The Trainable params total is exactly the sum of the parameters of the last three fully connected layers, i.e. the layers that are actually being trained.
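As a sanity check, these numbers can be reproduced by hand: a 3x3 Conv2D layer has 3*3*in_ch*out_ch weights plus out_ch biases, and a Dense layer has in*out weights plus out biases. The sketch below recomputes the summary's totals from the layer shapes printed above:

```python
# Reproduce the summary's parameter counts from the layer shapes.

def conv_params(in_ch, out_ch, k=3):
    # k*k kernel per (input channel, output channel) pair, plus one bias per filter
    return k * k * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

# (in_ch, out_ch) for each conv layer of VGG16, block by block
convs = [(3, 64), (64, 64),
         (64, 128), (128, 128),
         (128, 256), (256, 256), (256, 256),
         (256, 512), (512, 512), (512, 512),
         (512, 512), (512, 512), (512, 512)]

vgg_total = sum(conv_params(i, o) for i, o in convs)
head_total = (dense_params(8 * 8 * 512, 4096)   # Flatten of (8, 8, 512) -> Dense
              + dense_params(4096, 4096)
              + dense_params(4096, 5))

print(vgg_total)               # 14714688, the frozen (non-trainable) part
print(head_total)              # 151023621, the trainable classification head
print(vgg_total + head_total)  # 165738309, the summary's total params
```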

Fine-tuning the Model

Depending on our needs, we can also leave the last few layers of VGG16 unfrozen and mark them trainable. For example, we can make real training start at layer 15, while every layer before layer 15 keeps its original ImageNet pretrained weights:

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers

# use data augmentation layer
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
x = base_model.output
x = layers.Flatten()(x)
x = layers.Dense(4096, activation="relu")(x)
x = layers.Dense(4096, activation="relu")(x)
outputs = layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# freeze all layers before the 15th layer
for layer in base_model.layers[:15]:
    layer.trainable = False

# allow training from layer 15 onward
for layer in base_model.layers[15:]:
    layer.trainable = True

Running model.summary() gives:

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
sequential_2 (Sequential)    (None, 256, 256, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 256, 256, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 256, 256, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 128, 128, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 128, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 128, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 64, 64, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 64, 64, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 32, 32, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 32, 32, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 16, 16, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 8, 8, 512)         0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 32768)             0         
_________________________________________________________________
dense_6 (Dense)              (None, 4096)              134221824 
_________________________________________________________________
dense_7 (Dense)              (None, 4096)              16781312  
_________________________________________________________________
dense_8 (Dense)              (None, 5)                 20485     
=================================================================
Total params: 165,738,309
Trainable params: 158,103,045
Non-trainable params: 7,635,264
_________________________________________________________________

You can see that Trainable params is now the parameter sum from layer 15 onward, which means training really begins at layer 15.
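The fine-tuning split can be verified the same way: base_model.layers[:15] covers the input layer, the augmentation Sequential, and the conv/pool layers up through block4_conv3, so only the parameters of blocks 1 through 4 end up frozen:

```python
# Check the fine-tuning split by hand: the frozen slice ends at block4_conv3.

def conv_params(in_ch, out_ch, k=3):
    return k * k * in_ch * out_ch + out_ch

frozen_convs = [(3, 64), (64, 64),                    # block1
                (64, 128), (128, 128),                # block2
                (128, 256), (256, 256), (256, 256),   # block3
                (256, 512), (512, 512), (512, 512)]   # block4

frozen = sum(conv_params(i, o) for i, o in frozen_convs)
total = 165_738_309                                   # from the summary above

print(frozen)          # 7635264, matches Non-trainable params
print(total - frozen)  # 158103045, matches Trainable params
```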

You can print the name of layer 15 with:

print(base_model.layers[15].name)

Output:

block4_pool
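Reading the summaries above, we can write out the layer order to see exactly what base_model.layers[:15] freezes. The first two names below are placeholders (Keras auto-generates names like input_3 and sequential_2); the rest are the VGG16 layer names as printed:

```python
# Layer order taken from the model summaries printed above: index 0 is the
# Input layer and index 1 is the augmentation Sequential, which shifts the
# VGG16 blocks by two positions.
layer_names = [
    "input", "augmentation",
    "block1_conv1", "block1_conv2", "block1_pool",
    "block2_conv1", "block2_conv2", "block2_pool",
    "block3_conv1", "block3_conv2", "block3_conv3", "block3_pool",
    "block4_conv1", "block4_conv2", "block4_conv3", "block4_pool",
    "block5_conv1", "block5_conv2", "block5_conv3", "block5_pool",
]

print(layer_names[:15])   # frozen: everything up to and including block4_conv3
print(layer_names[15])    # block4_pool, the first layer left trainable
```

Keep this offset in mind: without the augmentation Sequential in front, the same index 15 would land on a different VGG16 layer.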

Bonus 1: Building with the Sequential API

If you prefer, the transfer learning model above can also be built with the Sequential API:

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# use data augmentation layer
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dense(4096, activation='relu'))
model.add(Dense(5, activation='softmax'))

# freeze the entire VGG16 base in one call
base_model.trainable = False

Model summary:

Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 8, 8, 512)         14714688  
_________________________________________________________________
flatten_3 (Flatten)          (None, 32768)             0         
_________________________________________________________________
dense_9 (Dense)              (None, 4096)              134221824 
_________________________________________________________________
dense_10 (Dense)             (None, 4096)              16781312  
_________________________________________________________________
dense_11 (Dense)             (None, 5)                 20485     
=================================================================
Total params: 165,738,309
Trainable params: 151,023,621
Non-trainable params: 14,714,688
_________________________________________________________________

And here is the version that trains from layer 15 onward:

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# use data augmentation layer
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)

inputs = tf.keras.Input(shape=(256, 256, 3))
base_model = data_augmentation(inputs)
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=base_model)
model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dense(4096, activation='relu'))
model.add(Dense(5, activation='softmax'))

# freeze all layers before the 15th layer
for layer in model.layers[0].layers[:15]:
    layer.trainable = False
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 8, 8, 512)         14714688  
_________________________________________________________________
flatten_4 (Flatten)          (None, 32768)             0         
_________________________________________________________________
dense_12 (Dense)             (None, 4096)              134221824 
_________________________________________________________________
dense_13 (Dense)             (None, 4096)              16781312  
_________________________________________________________________
dense_14 (Dense)             (None, 5)                 20485     
=================================================================
Total params: 165,738,309
Trainable params: 158,103,045
Non-trainable params: 7,635,264
_________________________________________________________________

Either way, the results are identical.

Bonus 2: Rewriting the Data Augmentation Layers with the Functional API

The model-building styles can be mixed and matched. Here are the data augmentation layers rewritten in the functional API style:

# assuming the input tensor is named inputs
x = layers.RandomFlip("horizontal")(inputs)
x = layers.RandomRotation(0.1)(x)
x = layers.RandomZoom(0.2)(x)

Today we went over the concepts behind transfer learning. Interesting, isn't it? Tomorrow we start stepping on the gas~~~


Previous: [Day 10] Model Building Methods (2): The Functional API
Next: [Day 12] Training, Start! Compiling the Model