iT邦幫忙

2024 iThome 鐵人賽

DAY 9
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 9

[Day9] pytorch lightning (實作) - 2

  • 分享至 

  • xImage
  •  

前情提要: 昨天已經完成model.py, dataloader.py, 我自己習慣分成這兩個檔案,因為等到之後model的code越來越多,已經不適合跟train_step這些寫在一起。

1. main.py 基礎元素

這裡一樣有幾項固定的東東:
這些名稱都是固定的哦

  1. forward, training_step, validation_step: batch跟batch_idx名稱也固定,這裡的batch就是dataloader回傳的一個batch,可以把它想像成在訓練的時候他會自己透過dataloader取一個batch,接下來就要寫送到model裡面,再來計算loss。
  2. configure_optimizers: 設定optimizer的地方,通常會看模型而選用特定的optimizer,常用的是SGD, adam, adamW。
  3. train_dataloader, val_dataloader: 這裡就直接回傳DataLoader,如果想要有多個dataloader,也可以回傳多個,這裡我們就先簡單一點。
from torch.utils.data import DataLoader

import lightning as pl

from model import MNISTClassifier
from dataloader import CustomDataset

class example(pl.LightningModule):
    def __init__(
            self, 
            batch_size = 16,
            train_txt = "/ws/code/Day8/train.txt",
            val_txt = "/ws/code/Day8/test.txt",
        ):
        super().__init__()
        self.batch_size = batch_size

        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)

        self.model = MNISTClassifier()


    def forward(self, batch):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass
    
    def configure_optimizers(self):
        pass

    def train_dataloader(self):

        return DataLoader(
            self.train_dataset, 
            batch_size = self.batch_size, 
            shuffle = True, 
            drop_last = True, 
            num_workers = 4,
        )

    def val_dataloader(self):
        
        return DataLoader(
            self.val_dataset, 
            batch_size = self.batch_size, 
            shuffle = False, 
            drop_last = True, 
            num_workers = 4,
        )
        

2. training_step

這裡batch就要回到我們之前的dataloader囉,我們回傳兩個東西,image, label,那我自己在step裡面會x, y表示,比較簡單。
這就是我之前所說的會寫成一個個function,簡單明瞭,然後你會發現好像少了pytorch常寫的loss.backward()…,主要是人家training_step都幫你包好好了,會自己去做backward等等,當然也可以手動更新。
流程如下:

  1. batch 分為 x(輸入), y(答案)
  2. 送到 model
  3. 計算loss
  4. 將loss的值寫到self.log,這個log就可以用tensorboard來圖示化,看訓練曲線也是很重要的!!
  5. return loss
    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("train_loss", loss.item(), prog_bar = True)

        return loss

3. validation_step

基本上前半段跟training_step一樣,那在這邊我自己習慣多metrics,也就是來評估模型訓練好不好,以下是各類評估常用到的,可以根據自己的任務選擇。

  • 分類最常用的幾個: accuracy, Precision, Recall, F1
  • ASR常用: WER, CER
  • speech enhancement: PESQ, STOI
  • 影像切割: IOU
  • 文字: ROUGE

這裡我們就選最簡單accuracy來實作,我們來透過使用torchmetrics這個包來計算吧~~
這裡的torchmetrics寫法有兩種,我選擇用第二種來做,也就是在validation_step update,on_validation_epoch_end來compute,可以把它想像成validation每跑一個batch就透過update更新,等到整個epoch跑完透過compute算出最後結果,然後紀錄在log並且reset用於下一個epoch計算。

from torch.utils.data import DataLoader
import torchmetrics
import lightning as pl

from model import MNISTClassifier
from dataloader import CustomDataset

class example(pl.LightningModule):
    def __init__(
            self, 
            batch_size = 16,
            train_txt = "/ws/code/Day8/train.txt",
            val_txt = "/ws/code/Day8/test.txt",
        ):
        super().__init__()
        self.batch_size = batch_size

        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)

        self.model = MNISTClassifier()
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes = 10)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("val/loss", loss.item(), prog_bar = True)

        self.valid_acc.update(preds, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

4. configure_optimizers

這邊就是設定optimizer跟lr_scheduler,簡單一點的就是固定learning rate,也可以透過scheduler來調整。
這裡的self.parameters()會去抓所有可更新的參數,在這邊就是self.model。

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
        # lr_scheduler 

        return optimizer # [optimizer], [lr_scheduler]

今天就更新到這囉,可以消化一下。
明天把後續更新完


上一篇
[Day8] pytorch lightning介紹 - 1
下一篇
[Day10] pytorch lightning 實作 - 3
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言