iT邦幫忙

2024 iThome 鐵人賽

DAY 13
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 13

[Day 13] pytorch lightning (predict_step)

  • 分享至 

  • xImage
  •  

非常感謝昨天有網友分享自己訓練的loss圖,不過我沒辦法回復,所以就此感謝,也歡迎大家一起分享。

前面大致上把模型訓練跟一些基礎的觀念講完了,現在我們來講講對碩班同學不重要,但實際上班非常重要的東西。

我看蠻多碩士生,當論文數據跑出來之後就結束了,對我來說效果再怎麼好,沒有實際應用也是空談,所以這裡希望碩士生加減能碰一下deploy model(部屬模型),主要就是將你的code包成一個可以東西放到任何機器上跑。

接下來這幾天會教重寫inference的code,fastapi,docker build image,經過這些你就可以deploy model囉~~

1. infer.py(predict_step)

一般在deploy model的時候,我喜歡將原本main.py的code重寫一次,並叫成infer.py,如果需要我還會把dataloader.py重寫,只留下inference需要用的或改成inference要用的,我一開始學的時候不知道甚麼是inference,翻譯成中文是"推論",那時候一頭霧水,簡單來說就是透過你的model跑出結果,這個過程稱為inference。

會發現我只留下相關的東西(metrics跟loss就不需要囉),之前我沒有特別說到,為甚麼之前forward裡面都沒寫東西呢?? 主要是我們的forward已經在model.py都做了,所以就不太需要寫forward,如果要寫也可以像下面一樣,在predict_step當中的self(batch),指的是呼叫forward能直接簡寫成self(),主要是要講這個觀念。

接下來trainer.predict我自己喜歡從外面帶入Dataloader,並且回傳結果。

from torch.utils.data import DataLoader
import lightning as pl
import torch

from model import MNISTClassifier
from dataloader_infer import CustomDataset

class example(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MNISTClassifier()

    def forward(self, batch):
        preds = self.model(batch)
        return preds
    
    def predict_step(self, batch, idx):
        pred = self(batch).argmax(dim = -1)
        return pred.cpu().detach()
    
if __name__ == "__main__":
    model = example()
    batch_size = 1
    trainer = pl.Trainer(logger = False)
    ckpt_path = "./last.ckpt"
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"]
    model.load_state_dict(ckpt, strict = False)
    model.eval()
    img_path = './0.jpg'
    dataset = CustomDataset(img_path)
    dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
    pred = trainer.predict(model, dataloader, return_predictions = True)
    print(pred)

2. dataloader_infer.py

也會修改一下這個code,因為你不會知道label是甚麼,所以少回傳一個參數,再來以這個範例,我們一次只辨識一張,所以原先是開txt檔的過程我們也省略,直接給他img_path

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    def __init__(self, img_path):
        self.data = [img_path]

        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),  # 確保圖片大小一致
            transforms.ToTensor(),        # 轉換為PyTorch張量
            transforms.Normalize((0.5, ), (0.5, ))  # 標準化
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path = self.data[idx]
        image = Image.open(path).convert('L')  # MNIST是灰度圖,轉換為'L'模式
        image = self.transform(image)

        return image


if __name__ == "__main__":
    img_path = './0.jpg'
    unit_test = CustomDataset(img_path)
    for idx, (image, ) in enumerate(unit_test):
        print(f'image: {image.size()}')


今天就到這裡囉~
主要就是精簡化的code,只保留inference需要的而已,雖然需要額外在寫,不過我覺得精簡過後的簡單明瞭,就給大家參考囉~


上一篇
[Day 12] loos 曲線
下一篇
[Day 14] fastapi 介紹
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言