iT邦幫忙

2022 iThome 鐵人賽

DAY 11
0
AI & Data

PyTorch 生態鏈實戰運用系列 第 11

[Day11] Build a trainable Lightning-Module

  • 分享至 

  • xImage
  •  

前言

接續前一日的文章,我們簡介了PyTorch Lightning以及如何利用Lightning Module將forward propagation進行簡單的封裝。今天我們將把Lightning Module完成,並利用新的格式進行模型訓練。

Lightning Module with Hyperparameters

先前已經利用YAML的方式進行超參數管理(詳見hparam.yaml),進一步,這裡可以把整個Configuration當作是整個Lightning Module的輸入,直接在模組內搭建模型的超參數。

module initialization

# initialize
def __init__(self, CONFIG : Dict, **kwargs):
    super().__init__()
    self.CONFIG = CONFIG # 把整個CONFIG放到CLASS內的一個property,以利其他函式可以重複利用。
    self.save_hyperparameters() # PL內建的儲存當次實驗超參數的函數
    self.backbone = get_backbone(CONFIG) # 用之前創立的函數,建立backbone
    self.loss_function = getattr(torch.nn, CONFIG['train']['loss_function'])() # 利用CONFIG設定需要的loss function

optimizer

def configure_optimizers(self):
    opt = getattr(torch.optim, self.CONFIG['train']['optimizer'])
    opt = opt(params=self.parameters(), 
              lr = self.CONFIG['train']['learning_rate'])
    return opt

這些都建立好以後,初步的一個lightning module就完成啦!

詳細可以看一下model.py,一樣可以使用下列指令來進行測試整個模型的建立狀況:

python src/model.py --config=hparams.yaml 

Trainer 與 Callback

在建立好Lightning-Module以後, 下一個問題便是怎麼使用它來訓練以及儲存結果了。這裡介紹兩個最常用的功能,分別是trainer以及callback。

在先前的PyTorch實作當中,我們使用了每個EPOCH訓練後,驗證集的AUROC來作為挑選模型的標準,進而儲存要選擇的模型。在PL當中,我們會利用一個Callback的物件,在模型過程中紀錄各式各樣的事情,例如下列的ModelCheckpoint就可以用來儲存模型的結果:

checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=CONFIG['train']['weights_folder'],
                                                   monitor= 'val/auroc',
                                                   mode='max',
                                                   save_top_k=3,
                                                   filename = 'epoch_{epoch:02d}_val_loss_{val/loss:.2f}_val_acc_{val/acc:.2f}_val_auroc_{val/auroc:.2f}',
                                                   auto_insert_metric_name = False)

這邊用白話來講就是,利用指定的名稱,根據validation set的auroc,去儲存最高的三個權重組到指定的資料夾內。另外是為了在訓練途中,callback能夠認識到val\auroc,我們需要利用Lightning-Module本身的log功能,在訓練的時紀錄所需要的metrics:

def validation_epoch_end(self, validation_step_outputs: List[Any]):
    preds = torch.cat([output['preds'] for output in validation_step_outputs], dim=0)
    labels = torch.cat([output['labels'] for output in validation_step_outputs], dim=0)
    probs = torch.nn.Sigmoid()(preds)

    # compute metrics and log
    acc_score = torchmetrics.functional.accuracy(probs, labels, mdmc_average = 'global')
    auc_score = monai.metrics.compute_roc_auc(probs, labels, average='macro')
    self.log('val/acc', acc_score.item())
    self.log('val/auroc', auc_score.item())

最後是設置Trainer,以及要訓練的模型與對應的相關參數:

from src import model

net = model.MultiLabelsModel(CONFIG)
trainer = pl.Trainer(
        callbacks = checkpoint_callback,
        default_root_dir = CONFIG['train']['weights_folder'],
        max_epochs = CONFIG['train']['max_epochs'],
        limit_train_batches = CONFIG['train']['steps_in_epoch'],
        accelerator = 'cuda',
        devices = 1)
        
# model training
trainer.fit(net, 
            data_generators['TRAIN'], 
            data_generators['VALIDATION'])

接著就可以開始進行模型訓練了!

實作

一樣這次的實作有放到這個commit內,只要執行:

# python src/train.py --config=hparams.yaml
...
Epoch 24: 100%|█████████████████████████████████████████████| 138/138 [00:40<00:00,  3.42it/s, loss=0.194, v_num=16]

可以看到相同的參數底下,pl所訓練+驗證一個epoch的時間約是40秒。這項與先前在DAY08所實作的PyTorch訓練約30秒及驗證10秒加起來的數值相近,可見目前為止的效率而言是幾乎一樣的!

另一方面,以這次筆者實際訓練的結果為例,可以看到Callback所儲存的結果如下圖:

也與DAY08所得到的模型結果,auroc=0.58十分接近!

本日小節

  • 完成Lightning-Module
  • 使用checkpoint以及trainer完成訓練

上一篇
[Day10] Pytorch Lightning
下一篇
[Day12] Training Log and History
系列文
PyTorch 生態鏈實戰運用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言