iT邦幫忙

2022 iThome 鐵人賽

DAY 8
0
AI & Data

PyTorch 生態鏈實戰運用系列 第 8

[Day08] Model Validation with PyTorch

  • 分享至 

  • xImage
  •  

前言

前一日已經開始進行模型的訓練。本日將討論要如何確認或挑選訓練出來的模型是否真的好?真的朝著正確的方向在邁進呢?

過擬合 (Overfitting)

在訓練的過程當中,很多情況只要是Training Code沒有異常的bugs的情況底下,在訓練集上的loss通常只會不斷下降。

這是否表示,我訓練出來的模型正在不斷變好呢? 這個答案,你知、我知、獨眼龍也知道,當然是! 主因是在機械學習越來越發展以來,模型內的參數也隨之越來越多。在這種情況底下,模型很有可能會去把某些樣本或是巧合硬是記下來,換句人話就是模型把答案硬背下來了!(尤其是在Deep Learning時代下,參數又多,訓練集的樣本又每張都看過幾十幾百次的狀況下,更加的容易發生。)

上圖則是參考Wiki Overfitting條目當中的圖,其中綠色的線就是想表達一個過擬合的圖。

雖然它完全正確分出紅色藍色了,但我們事實上很害怕這樣的模型,如同上面所說的,它只是用極端的狀況去硬記訓練資料,進而在實際上無法套用到新的或實際的資料上。

驗證集(Validation Set)

為了去驗證我們的模型套用沒有學習過的資料時的效果,一般我們會保留一份資料用來檢查模型表現,這個資料子集通常我們就稱作為驗證集。(另外還有測試集,後續會再介紹。)

常見的具體實作,通常我們每次訓練模型到一個段落的時候,會使用當下的模型針對驗證集內的所有資料進行推論,並紀錄當下的模型在驗證集上各種metric的表現,進而評估模型的好壞。

以我們的Multi-Label Classification來說,我們最主要就是比較準確率(Accuracy)以及AUROC(Area Under the Receiver Operating Characteristic),這部份在TorchmetricsMONAI上都可以找到對應的函數可以使用。

本次的具體實作可以在每一個epoch的後面加上:

model.eval()
with torch.no_grad():
    y_pred = torch.tensor([], dtype=torch.float32, device=device)
    y = torch.tensor([], dtype=torch.long, device=device)
    pbar = tqdm.tqdm(data_generators['VALIDATION'], total = len(processed_datasets['VALIDATION']) // data_generators['VALIDATION'].batch_size)
    for batch in pbar:
        val_images, val_labels =  batch['img'].to(device), batch['labels'].to(device)
        y_pred = torch.cat([y_pred, model(val_images)], dim=0)
        y = torch.cat([y, val_labels], dim=0)
        pbar.set_description('Validating ...')
    y_prob = torch.nn.Sigmoid()(y_pred)
    loss = loss_function(y_pred, y.float()).item()
    acc_score = torchmetrics.functional.accuracy(y_prob, y, mdmc_average = 'global').item()
    auc_score = monai.metrics.compute_roc_auc(y_prob, y, average='macro').item()

這裡要注意幾個重要的小細節,分別是

  • model.eval():做推論的模式切換,沒有做的話像是Dropout或是Batch Normalization就會根據訓練的模式跑出不正確的結果。
  • with torch.no_grad():使用沒有梯度的模式進行運算,節省運算資源。

實作

新增了Validation以後的實作一樣放在Github對應的commit內,簡單執行,等待一下就可以得到結果:

# python src/train.py
----------
----------
epoch 24/25
Training Epoch 50/50train_loss: 0.2043:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋  | 49/50 [00:31<00:00,  1.57it/s]
epoch 24 average loss: 0.1974
Validating ...: : 88it [00:10,  8.11it/s]                                                                                                                                                     
current epoch: 24 current loss : 0.1849 current AUC: 0.5837 current accuracy: 0.9489 best AUC: 0.5849 at epoch: 23
----------
epoch 25/25
Training Epoch 50/50train_loss: 0.2136:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋  | 49/50 [00:29<00:00,  1.68it/s]
epoch 25 average loss: 0.1972
Validating ...: : 88it [00:11,  7.84it/s]                                                                                                                                                     
saved new best metric model
current epoch: 25 current loss : 0.1843 current AUC: 0.5866 current accuracy: 0.9490 best AUC: 0.5866 at epoch: 25
train completed, best_metric(AUC): 0.5866at epoch: 25

這裡可以注意到幾點分別為:

  • 最後的training loss / validation loss 分別是 0.2136 與 0.1972 算是還不錯,代表兩者都很接近,到第25個epoch都還沒有什麼嚴重的overfitting發生。
  • 準確率來到 0.9490,比起MedMNIST網站上還要高
  • 但是AUROC的部份,僅0.5866,跟MedMNIST上的0.778相差甚遠
  • 綜合上述兩點,大概可以了解到,目前模型大概就是偏保守的模型,大部分都預測0,所以準確率很高,但實際要抓到病徵的能力還遠遠不夠。

本日小節

  • 簡介Overfitting的概念
  • 加入驗證集進入訓練的環節
  • 模型有待加強!

上一篇
[Day07] Model Training with PyTorch
下一篇
[Day09] Deep Learning with Configuration
系列文
PyTorch 生態鏈實戰運用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言