iT邦幫忙

2021 iThome 鐵人賽

DAY 9
1
AI & Data

手寫中文字之影像辨識系列 第 9

【第9天】訓練模型-遷移學習

  • 分享至 

  • xImage
  •  

摘要

  1. 遷移學習說明
  2. 遷移學習類型
  3. 淺談預訓練與微調
  4. 如何進行遷移學習

內容

  1. 說明:基於資料集(ImageNet分類包括蛇、蜥蜴)、任務(皆為圖片分類)相似性,將預訓練的模型應用在新資料集的學習過程。通常用來解決以下問題:

    1.1 處理大量未標記資料(如:利用預訓練的vgg16模型,辨識貓狗照片,並進行標籤)

    1.2 降低大量資料的訓練成本:大量或性質相似的資料集,以遷移學習提高訓練效率。(節省時間、硬體資源)

    1.3 醫療應用需求(如:腎臟病變切片樣本或有標記的樣本稀少)

  2. 遷移學習類型

    2.1 基於實例:以ImageNet常見的1000種分類為例。

    • 1000種分類中有鳥、蜥蜴、猴子、蛇...等動物(domain)。
    • 若任務是辨識圖片中是蛇與蜥蜴(task),可手動調整蛇和蜥蜴的權重。(依照經驗調整)

    2.2 基於特徵:

    • 特徵萃取:若任務是分辨10種不同的蛇(task),先以預訓練模型(domain)對10種蛇做特徵萃取,再將特徵餵入自己定義的神經網路訓練。
    # 特徵萃取
    def feature_extraction_InV3(img_width, img_height,
                         train_data_dir,
                         num_image,
                         epochs):
        base_model = InceptionV3(input_shape=(299, 299, 3),
                              weights='imagenet', include_top=False)
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
    
        model = Model(inputs=base_model.input, outputs=x)
    
        train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(train_data_dir,
        target_size=(299, 299),
        # 每次進來訓練並更新學習率的圖片數 -> if 出現 memory leak -> 調低此參數
        batch_size=18,
        class_mode="categorical",
        shuffle=False)
    
        y_train=train_generator.classes
        # 依據class數量而定, np.zeros -> 宣告全部為0的空陣列
        y_train1 = np.zeros((num_image, 5))
        # np.arrange打標籤
        y_train1[np.arange(num_image), y_train] = 1
    
        # 重設generator
        train_generator.reset
        X_train=model.predict_generator(train_generator, verbose=1)
        print(X_train.shape, y_train1.shape)
        return X_train, y_train1, model
    
    # 自定義全連接層
    def train_last_layer(img_width, img_height,
                         train_data_dir,
                         num_image,
                         epochs):
        # 處理train資料夾
        X_train, y_train, model=feature_extraction_InV3(img_width, img_height,
                             train_data_dir,
                             num_image,
                             epochs)
    
        # 處理test資料夾
        X_test,y_test,model=feature_extraction_InV3(img_width,img_height,
                             test_data_dir,
                             num_test_image,
                             epochs)
    
        my_model = Sequential()
        my_model.add(BatchNormalization(input_shape=X_train.shape[1:]))
        my_model.add(Dense(1024, activation="relu"))
        my_model.add(Dense(5, activation='softmax'))
        my_model.compile(optimizer="SGD", loss='categorical_crossentropy',metrics=['accuracy'])
        print(my_model.summary())
    
        history = my_model.fit(X_train, y_train, epochs=20,
                  validation_data=(X_test, y_test),
                  batch_size=30, verbose=1)
        my_model.save('model_CnnModelTrainWorkout_v3_5calsses.h5')
        return history
    

    2.3 基於模型:task與domain參數共享。如:預訓練模型僅修改輸出層(分2類),載入權重進行參數遷移學習。

    2.4 基於關係(參考許多文章與書籍,這個類別還是無法理解。歡迎有研究的夥伴,留言分享。)

  3. 預訓練與微調

    3.1 預訓練:訓練模型時,從一開始的隨機初始化參數,到隨著訓練調整參數,完成訓練並儲存參數的過程。

    3.2 微調:將預訓練獲得的參數,作為新資料集訓練模型的初始參數,訓練後獲得適應新資料集的模型。

  4. 遷移學習過程

    4.1 挑選預訓練模型:以tensorflow框架為例,可至Keras Application,挑選準確度高、輕量化的預訓練模型,逐一訓練、比較。

    4.2 選擇訓練方式

    • <資料集大,目標域(task)相似於來源域(domain)>

      載入模型結構,將預訓練權重當作初始化參數,凍結底層卷積層,僅訓練部分卷積層與頂端層。(部分凍結)

    • <資料集大,目標域(task)不同於來源域(domain)>

      載入模型結構,將預訓練權重當作初始化參數,以新資料集全部重新訓練。(不凍結)

    • <資料集小,目標域(task)相似於來源域(domain)>

      載入模型結構與權重,僅修改輸出層(如:分1000類改成3類)或整個全連接層,再進行模型訓練。

    • <資料集小,目標域(task)不同於來源域(domain)>

      以資料擴增增加訓練樣本,後續步驟同<資料集大,目標域(task)不同於來源域(domain)>。

    ※註:通常每個類別的資料數量小於1000筆,視為小資料集。


小結

  1. 資料前處理後,新資料集約有19.3萬張。其中,每個中文字約有80-300張圖檔,且中文字不在1000個類別內。故屬於「資料集小,目標域不同於來源域」。
  2. 下一章,目標是:「介紹Tensorflow Keras Application,並挑選預訓練模型」。

讓我們繼續看下去...


參考資料

  1. 深度學習不得不會的遷移學習Transfer Learning
  2. 迁移学习(Transfer),面试看这些就够了!
  3. 第十一章 迁移学习

上一篇
【第8天】訓練模型-CNN與訓練方向
下一篇
【第10天】訓練模型-預訓練模型
系列文
手寫中文字之影像辨識31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言