iT邦幫忙

2022 iThome 鐵人賽

DAY 10
0
AI & Data

PyTorch 生態鏈實戰運用系列 第 10

[Day10] Pytorch Lightning

  • 分享至 

  • xImage
  •  

前言

本日將簡單介紹Pytorch-Lightning,而在包含今日的未來幾天內,會將先前構築的程式碼,分段整合成Pytorch-Lightning的格式。本日的部份是forward propogation的部份。

Pytorch Lightning 是什麼?

Pytorch Lightning是一個標榜同時可以簡化工程作業量,又同時具備高擴充性的Pytorch相容框架。
其與Pytorch的關係,有點類似TensorFlowKeras

(註:筆者本人也用過一陣子Keras跟TensorFlow 2,Keras的操作更加簡易,但是Flexibility就不太令人滿意,尤其是要客製一些框架或是訓練策略的時候,反倒是TensorFlow 2還順手一些)

主要的概念跟作法,可以直接參考下列這個來自Pytorch Lightning官方文件 LIGHTNING IN 15 MINUTES的簡介影片:

為什麼會需要 Pytorch Lightning?

簡單看完影片以後,相信大概能有個概念。現在來舉一個最簡單的例子,讓我們先來回顧前幾日的train.py裡頭每個epoch的training跟validation是怎麼做的?

Training Phase:

inputs, labels = batch['img'].to(device), batch['labels'].float().to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()

Validation Phase

for batch in pbar:
    step += 1
    val_images, val_labels =  batch['img'].to(device), batch['labels'].to(device)
    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
    y = torch.cat([y, val_labels], dim=0)
    pbar.set_description('Validating ...')
y_prob = torch.nn.Sigmoid()(y_pred)
loss = loss_function(y_pred, y.float()).item()

有沒有發現驚人的重工之處?

基本上都是在forward propogation以後在計算loss,只是差在有沒有梯度下降的差異而已。而且其實重複的且類似但又不太一樣的程式碼,也是增加進行小修改時出錯的風險。

Lightning Module

Pytorch Lightning 裡最核心的api就屬 Lightning Module,只要把模型整合成這個物件,基本上就可以開啟 Pytorch Lightning內的各種強大的支援。

根據文件中的內容,可以透過這個api把上一段落內的training跟validation大致整成下面的架構如下:

import pytorch_lightning as pl

class MultiLabelsModel(pl.LightningModule):
    """
    Lightning Module of Multi-Labels Classification for ChestMNIST
    """
    def __init__(self, CONFIG):
        self.backbone = get_backbone(CONFIG)
        ... # 可以網羅各種的初始設定,通常我會把大部分的超參數
        ... # 還有一些實驗過程需要的額外物件放在這個地方

    def forward(self, x):
        y = self.backbone(x)  # model inference 的主體,使用很自由
        return y              # 不論是要加層,增加input或output都可以簡單實現

    def step(self, batch: Any):
        inputs, labels = batch['img'].to(self.device), batch['labels'].to(self.device)
        preds = self.forward(inputs)
        loss = self.loss_function(preds, labels.float())
        return inputs, preds, labels, loss

    def training_step(self, batch: Any, batch_idx: int):
        inputs, preds, labels, loss = self.step(batch)
        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        inputs, preds, labels, loss = self.step(batch)
        return {
            'preds' : outputs,
            'labels' : labels
        }

    def validation_epoch_end(self, validation_step_outputs: List[Any]):
        preds = torch.cat([output['preds'] for output in validation_step_outputs], dim=0)
        labels = torch.cat([output['labels'] for output in validation_step_outputs], dim=0)
        probs = torch.nn.Sigmoid()(preds)
        
        # compute metrics and log
        acc_score = torchmetrics.functional.accuracy(probs, labels, mdmc_average = 'global')
        auc_score = monai.metrics.compute_roc_auc(probs, labels, average='macro')
    ...

透過呼叫共用的step,就可以讓分別對應的training_stepvalidation_step都能實現與原先相同的forward propogation。而要蒐集整個validation結果,進而計算accuracy與auc的部份,則可以在validation_epoch_end內,會自動將每個validation_step的output作為input輸入,就可以計算整個驗證集的指標了。

如此切割各個功能後,除了可讀性上比較好一些,要debug也會比較容易一些,可說是好處多多。後續還有許多設計檔的瑣碎工作需要做,就讓我們挪到後續幾天再來慢慢完成!

本日小節

  • 簡介 Pytorch Lightning
  • 介紹PyTorch的forward propogation整合成Lightning Module的形式

上一篇
[Day09] Deep Learning with Configuration
下一篇
[Day11] Build a trainable Lightning-Module
系列文
PyTorch 生態鏈實戰運用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言