iT邦幫忙

2021 iThome 鐵人賽

DAY 14
0
AI & Data

手寫中文字之影像辨識系列 第 14

【第14天】訓練模型-Xception

摘要

  1. Xception

    1.1 來源
    1.2 架構
    1.3 特性

  2. 訓練過程

    2.1 預訓練模型
    2.2 設置Callbacks
    2.3 設置訓練集
    2.4 開始訓練模型
    2.5 儲存模型與紀錄學習曲線

  3. 模型訓練結果

    3.1 學習曲線
    3.2 Accuracy與Loss

  4. 驗證模型準確度

    4.1 程式碼
    4.2 驗證結果


內容

  1. Xception

    1.1 來源:

    • 簡介:改良InceptionV3的Inception module,並引入depthwise separable convolution概念。
    • 時程:於2016年提出論文,並收錄於2017年的CVPR。
    • 論文名稱:Xception:Deep Learning with Depthwise Separable Convolutions

    1.2 架構

    • 以改良後的Extreme Inception取代InceptionV3的Inception module。(對照圖如下)

      • Extreme Inception(Xception)

      • Inception module(InceptionV3)

    • Extreme Inception引進Depthwise separable convolution概念降低網路的複雜度,同時拓寬網路,維持接近Inception module的參數量。

    • Standard convolution、Depthwise separable convolution與Extreme Inception。

      • Standard convolution(InceptionV3)

      • Depthwise separable convolution(MobileNets):將傳統卷積拆分成兩個步驟,在維持準確度的前提下,降低參數量與模型訓練時間。


        Depthwise → (b)channel卷積運算; Pointwise → (c)1x1卷積運算combining

      ※ 詳細請參考 論文中3.1Depthwise Separable Convolution

      • Extreme Inception(Xception):類似Depthwise separable convolution,只是兩者卷積運算的順序相反。

    1.3 特性

    • 觀察InceptionV3、Xception模型參數量,Xception僅略少,訓練時間稍微縮短。

    • Xception將空間相關性與通道相關性分離,更有效率的利用參數,模型準確度較高。

  2. 訓練過程:

    2.1 預訓練模型

    # IMPORT MODULES
    from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
    from keras.layers import Input, Dense, GlobalAveragePooling2D
    from keras.preprocessing.image import ImageDataGenerator
    from keras.optimizers import Adam
    import matplotlib.pyplot as plt
    from keras.models import Model
    from keras.applications import Xception
    
    # -----------------------------1.客製化模型--------------------------------
    # 載入keras模型(更換輸出圖片尺寸)
    model = Xception(include_top=False,
                     weights='imagenet',
                     input_tensor=Input(shape=(80, 80, 3))
                     )
    
    # 定義輸出層
    x = model.output
    x = GlobalAveragePooling2D()(x)
    predictions = Dense(800, activation='softmax')(x)
    model = Model(inputs=model.input, outputs=predictions)
    
    # 編譯模型
    model.compile(optimizer=Adam(lr=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    

    2.2 設置Callbacks

    # -----------------------------2.設置callbacks-----------------------------
    # 設定earlystop條件
    estop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=1)
    
    # 設定模型儲存條件
    checkpoint = ModelCheckpoint('Xception_checkpoint_v2.h5', verbose=1,
                              monitor='val_loss', save_best_only=True,
                              mode='min')
    
    # 設定lr降低條件(0.001 → 0.0005 → 0.000125 → 0.0001)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                               patience=5, mode='min', verbose=1,
                               min_lr=1e-4)
    

    2.3 設置訓練集

    # -----------------------------3.設置資料集--------------------------------
    # 設定ImageDataGenerator參數(路徑、批量、圖片尺寸)
    train_dir = './workout/train/'
    valid_dir = './workout/val/'
    test_dir = './workout/test/'
    batch_size = 32
    target_size = (80, 80)
    
    # 設定批量生成器
    train_datagen = ImageDataGenerator(rescale=1./255, 
                                       rotation_range=20,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       shear_range=0.2, 
                                       zoom_range=0.5,
                                       fill_mode="nearest")
    
    val_datagen = ImageDataGenerator(rescale=1./255)
    
    test_datagen = ImageDataGenerator(rescale=1./255)
    
    # 讀取資料集+批量生成器,產生每epoch訓練樣本
    train_generator = train_datagen.flow_from_directory(train_dir,
                                          target_size=target_size,
                                          batch_size=batch_size)
    
    valid_generator = val_datagen.flow_from_directory(valid_dir,
                                          target_size=target_size,
                                          batch_size=batch_size)
    
    test_generator = test_datagen.flow_from_directory(test_dir,
                                          target_size=target_size,
                                          batch_size=batch_size,
                                          shuffle=False)
    

    2.4 重新訓練模型權重

    # -----------------------------4.開始訓練模型------------------------------
    # 重新訓練權重
    history = model.fit_generator(train_generator,
                       epochs=50, verbose=1,
                       steps_per_epoch=train_generator.samples//batch_size,
                       validation_data=valid_generator,
                       validation_steps=valid_generator.samples//batch_size,
                       callbacks=[checkpoint, estop, reduce_lr])
    

    2.5 儲存模型與紀錄學習曲線

    # -----------------------5.儲存模型、紀錄學習曲線------------------------
    # 儲存模型
    model.save('./Xception_retrained_v2.h5')
    print('已儲存Xception_retrained_v2.h5')
    
    # 畫出acc學習曲線
    acc = history.history['accuracy']
    epochs = range(1, len(acc) + 1)
    val_acc = history.history['val_accuracy']
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend(loc='lower right')
    plt.grid()
    # 儲存acc學習曲線
    plt.savefig('./acc.png')
    plt.show()
    
    # 畫出loss學習曲線
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc='upper right')
    plt.grid()
    # 儲存loss學習曲線
    plt.savefig('loss.png')
    plt.show()
    
  3. 模型訓練結果

    3.1 訓練epochs:26 epochs。

    3.2 耗費時間:3小時29分24秒(12564秒)。

    3.3 學習曲線

    3.4 Accuary與Loss

  4. 驗證準確度

    4.1 程式碼

    # -------------------------6.驗證模型準確度--------------------------
    # 以vali資料夾驗證模型準確度
    test_loss, test_acc = model.evaluate_generator(test_generator,
                                steps=test_generator.samples//batch_size,
                                verbose=1)
    print('test acc:', test_acc)
    print('test loss:', test_loss)
    

    4.2 驗證結果


小結

下一章目標是:介紹第二個預訓練模型ResNet152V2,與分享訓練成果」。

讓我們繼續看下去...


參考資料

  1. Xception: Deep Learning with Depthwise Separable Convolutions
  2. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications

上一篇
【第13天】訓練模型-優化器(Optimizer)
下一篇
【第15天】訓練模型-ResNet152V2
系列文
手寫中文字之影像辨識31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言