iT邦幫忙

2024 iThome 鐵人賽

DAY 10
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 10

[Day10] pytorch lightning 實作 - 3

  • 分享至 

  • xImage
  •  

前情提要: 昨天基本上已經把大部分的code都寫完了,應該可以感受到lightning的簡潔。

1. loss_fn

在training_step跟validation_step有用到self.loss_fn主要就是loss function,如果是簡單的直接呼叫nn裡面的,那就寫成一行就好,但如果是比較複雜的,我自己會再創一個loss_function.py,然後再用import的方式帶進來。

self.loss_fn = nn.CrossEntropyLoss() #寫在__init__

2. main

以下一些設定會在明天做說明,可以先照抄,讓我們先把訓練跑起來。

def main(task, max_epochs):
    model = example()

    if task == 'train':
        callbacks = []
        dirpath = "./model"
        checkpoint_acc = ModelCheckpoint(
            save_top_k = 5,
            monitor = "valid_acc_epoch",
            mode = "max",
            dirpath = dirpath,
            filename = "model_{epoch:02d}_{valid_acc_epoch:.2f}",
            save_last = True,
        )
        callbacks.append(checkpoint_acc)
        trainer = pl.Trainer(
            max_epochs = max_epochs, 
            callbacks = callbacks, 
            gradient_clip_val = 2.0,
            # devices = [0, 1] # 多GPU訓練
        )  
        trainer.fit(model=model)
    # elif task == 'ft':

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Training script")
    parser.add_argument('--task', type = str, default = 'train')
    parser.add_argument('--max_epochs', type = int, default = 50)
    
    args = parser.parse_args()
    main(args.task, args.max_epochs)

3. 開始訓練

就下來就可以執行囉,自己在前幾篇有筆誤,Dataset的__len__,以及在model當中多加softmax,如果你是從前兩天一路慢慢打程式的,深感抱歉,目前已經更新,其實很多bug是只有在run的時候才會知道,有時候是手殘,有時是眼殘,但也有蠻多是邏輯有錯,有些東西是沒有考慮到的,這些就只能慢慢debug增加實力。

這裡可以看到,你執行的時候會列印出哪一些裝置可取得,model總共多少參數,有時候實作論文的時候,看參數對不對十分重要!!
https://ithelp.ithome.com.tw/upload/images/20240814/20168446e3gHL51ohD.png

然後在你的目錄底下多了一個lightning_logs,這個就是我一直說的tensorboard,我們進到此目錄,輸入以下指令,然後你就可以開啟chrome輸入網址: :,後面port看你當初起container前面的port,如果你是用自己電腦,ip為127.0.0.1,port是6066

tensorboard --logdir . --port 6066 --bind_all

https://ithelp.ithome.com.tw/upload/images/20240814/20168446Fly3jcq0s9.png

當中的not found可以忽略不看,可以看到最下會有TensorBoard 版本 at,就代表有起成功。
https://ithelp.ithome.com.tw/upload/images/20240814/20168446bpvppVuQTo.png

4. tensorboard

可以看到訓練的曲線以及acc,那至於有沒有符合預期,我們明天繼續講~
https://ithelp.ithome.com.tw/upload/images/20240814/20168446MTX4EoXbIh.pnghttps://ithelp.ithome.com.tw/upload/images/20240814/20168446B07LIlmJ5k.png

整個main.py

可以再把每個過程看一下,檢查看看你是否懂這塊了。

import argparse

from torch.utils.data import DataLoader
import torchmetrics
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
import torch.nn as nn

from model import MNISTClassifier
from dataloader import CustomDataset

class example(pl.LightningModule):
    def __init__(
            self, 
            batch_size = 128,
            train_txt = "/ws/code/Day8/train.txt",
            val_txt = "/ws/code/Day8/test.txt",
        ):
        super().__init__()
        self.batch_size = batch_size

        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)

        self.model = MNISTClassifier()
        self.valid_acc = torchmetrics.classification.Accuracy(task = "multiclass", num_classes = 10)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, batch):
        pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("train/loss", loss.item(), prog_bar = True)

        return loss


    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("val/loss", loss.item(), prog_bar = True)

        self.valid_acc.update(preds, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
        # lr_scheduler 

        return optimizer # [optimizer], [lr_scheduler]

    def train_dataloader(self):

        return DataLoader(
            self.train_dataset, 
            batch_size = self.batch_size, 
            shuffle = True, 
            drop_last = True, 
            num_workers = 4,
        )

    def val_dataloader(self):
        
        return DataLoader(
            self.val_dataset, 
            batch_size = self.batch_size, 
            shuffle = False, 
            drop_last = True, 
            num_workers = 4,
        )
    


def main(task, max_epochs):
    model = example()

    if task == 'train':
        callbacks = []
        dirpath = "./model"
        checkpoint_acc = ModelCheckpoint(
            save_top_k = 5,
            monitor = "valid_acc_epoch",
            mode = "max",
            dirpath = dirpath,
            filename = "model_{epoch:02d}_{valid_acc_epoch:.2f}",
            save_last = True,
        )
        callbacks.append(checkpoint_acc)
        trainer = pl.Trainer(
            max_epochs = max_epochs, 
            callbacks = callbacks, 
            gradient_clip_val = 2.0,
            # devices = [0, 1] # 多GPU訓練
        )  
        trainer.fit(model=model)
    # elif task == 'ft':

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Training script")
    parser.add_argument('--task', type = str, default = 'train')
    parser.add_argument('--max_epochs', type = int, default = 50)
    
    args = parser.parse_args()
    main(args.task, args.max_epochs)

今天就更新到這囉~ 花了很多時間才把最簡單的範例搞定,不過整個架構基本上差不多就這樣了,學會了基礎,接下來就可以往你想要研究的方向去寫一套屬於你的lightning code。

明天會把之前一些沒講清楚的部分說明清楚。


上一篇
[Day9] pytorch lightning (實作) - 2
下一篇
[Day 11] pytorch lightning 實作 - 4
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言