iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 25
0
AI & Data

AI從入門到放棄系列 第 25

Day 25 ~ AI從入門到放棄 - 遷移學習之四

  • 分享至 

  • xImage
  •  

我們創建一個簡單的Functional模型並畫出來。

from tensorflow.keras.layers import concatenate, Conv2D, Dense, GlobalAveragePooling2D, Input, MaxPooling2D
from tensorflow.keras.models import Model

inpt = Input(shape=(32, 32, 3))
conv1x1_1 = Conv2D(filters=64, kernel_size=(1, 1), padding='same', activation='relu')(inpt)
conv3x3_1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(inpt)
conv5x5_1 = Conv2D(filters=64, kernel_size=(5, 5), padding='same', activation='relu')(inpt)
x = concatenate([conv1x1_1, conv3x3_1, conv5x5_1], axis=3)
x = MaxPooling2D()(x)
conv1x1_2 = Conv2D(filters=64, kernel_size=(1, 1), padding='same', activation='relu')(x)
conv3x3_2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(x)
conv5x5_2 = Conv2D(filters=64, kernel_size=(5, 5), padding='same', activation='relu')(x)
x = concatenate([conv1x1_2, conv3x3_2, conv5x5_2], axis=3)
x = MaxPooling2D()(x)
x = Conv2D(filters=64, kernel_size=(1, 1), activation='relu')(x)
x = GlobalAveragePooling2D()(x)
x = Dense(units=10)(x)

model = Model(inputs=inpt, outputs=x)

先用summary觀察一下,可以發現模型層紀錄的順序跟我們搭建時一樣。

model.summary()
Model: "functional_1"
______________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                 
==============================================================================================
input_3 (InputLayer)            [(None, 32, 32, 3)]  0                                        
______________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 64)   256         input_3[0][0]                
______________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   1792        input_3[0][0]                
______________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 64)   4864        input_3[0][0]                
______________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 32, 32, 192)  0           conv2d_10[0][0]              
                                                                 conv2d_11[0][0]              
                                                                 conv2d_12[0][0]              
______________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 16, 16, 192)  0           concatenate_3[0][0]          
______________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 64)   12352       max_pooling2d_2[0][0]        
______________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 64)   110656      max_pooling2d_2[0][0]        
______________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 16, 16, 64)   307264      max_pooling2d_2[0][0]        
______________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 16, 16, 192)  0           conv2d_13[0][0]              
                                                                 conv2d_14[0][0]              
                                                                 conv2d_15[0][0]              
______________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 8, 8, 192)    0           concatenate_4[0][0]          
______________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 8, 8, 64)     12352       max_pooling2d_3[0][0]        
______________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 64)           0           conv2d_16[0][0]              
______________________________________________________________________________________________
dense_1 (Dense)                 (None, 10)           650         global_average_pooling2d_1[0][0] 
==============================================================================================
Total params: 450,186
Trainable params: 450,186
Non-trainable params: 0
______________________________________________________________________________________________

接著畫出圖來,如果你沒有將bin資料夾跟ipynb檔案放在一起,那請自行將bin/換成你的路徑,plot_model會建立一張model.png圖片出來。

import os
os.environ["PATH"] += os.pathsep + 'bin/'
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png')

https://ithelp.ithome.com.tw/upload/images/20200917/20129770BHqIvyQ6KQ.png
接著對前7層關閉訓練,可訓練的權重剩下,450186-0-256-1792-4864-0-0-12352=430922個。

for layer in model.layers[:7]:
  layer.trainable = False
model.summary()
Model: "functional_1"
______________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                 
==============================================================================================
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                        
______________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 64)   256         input_1[0][0]                
______________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 64)   1792        input_1[0][0]                
______________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 32, 32, 64)   4864        input_1[0][0]                
______________________________________________________________________________________________
concatenate (Concatenate)       (None, 32, 32, 192)  0           conv2d[0][0]                 
                                                                 conv2d_1[0][0]               
                                                                 conv2d_2[0][0]               
______________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 16, 16, 192)  0           concatenate[0][0]            
______________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 16, 16, 64)   12352       max_pooling2d[0][0]          
______________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 16, 16, 64)   110656      max_pooling2d[0][0]          
______________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 16, 16, 64)   307264      max_pooling2d[0][0]          
______________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 192)  0           conv2d_3[0][0]               
                                                                 conv2d_4[0][0]               
                                                                 conv2d_5[0][0]               
______________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 8, 8, 192)    0           concatenate_1[0][0]          
______________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 8, 8, 64)     12352       max_pooling2d_1[0][0]        
______________________________________________________________________________________________
global_average_pooling2d (Globa (None, 64)           0           conv2d_6[0][0]               
______________________________________________________________________________________________
dense (Dense)                   (None, 10)           650         global_average_pooling2d[0][0]
==============================================================================================
Total params: 450,186
Trainable params: 430,922
Non-trainable params: 19,264
______________________________________________________________________________________________

https://ithelp.ithome.com.tw/upload/images/20200917/201297709opzmw3zoT.png
想對Functional模型的層關閉訓練前可以畫圖出來看一下,盡量不要像圖中有一個關閉,其他兩個是開放訓練的狀態,在此例中可以關閉9層或6層,但是在大型模型中,比如之前的EfficientNetB0有非常多層,可能不易觀察。


上一篇
Day 24 ~ AI從入門到放棄 - 遷移學習之三
下一篇
Day 26 ~ AI從入門到放棄 - 貓狗辨識之一
系列文
AI從入門到放棄30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言