非常感謝昨天有網友分享自己訓練的loss圖,不過我沒辦法回復,所以就此感謝,也歡迎大家一起分享。
前面大致上把模型訓練跟一些基礎的觀念講完了,現在我們來講講對碩班同學不重要,但實際上班非常重要的東西。
我看蠻多碩士生,當論文數據跑出來之後就結束了,對我來說效果再怎麼好,沒有實際應用也是空談,所以這裡希望碩士生加減能碰一下deploy model(部屬模型),主要就是將你的code包成一個可以東西放到任何機器上跑。
接下來這幾天會教重寫inference的code,fastapi,docker build image,經過這些你就可以deploy model囉~~
一般在deploy model的時候,我喜歡將原本main.py的code重寫一次,並叫成infer.py,如果需要我還會把dataloader.py重寫,只留下inference需要用的或改成inference要用的,我一開始學的時候不知道甚麼是inference,翻譯成中文是"推論",那時候一頭霧水,簡單來說就是透過你的model跑出結果,這個過程稱為inference。
會發現我只留下相關的東西(metrics跟loss就不需要囉),之前我沒有特別說到,為甚麼之前forward裡面都沒寫東西呢?? 主要是我們的forward已經在model.py都做了,所以就不太需要寫forward,如果要寫也可以像下面一樣,在predict_step當中的self(batch),指的是呼叫forward能直接簡寫成self(),主要是要講這個觀念。
接下來trainer.predict我自己喜歡從外面帶入Dataloader,並且回傳結果。
from torch.utils.data import DataLoader
import lightning as pl
import torch
from model import MNISTClassifier
from dataloader_infer import CustomDataset
class example(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = MNISTClassifier()
def forward(self, batch):
preds = self.model(batch)
return preds
def predict_step(self, batch, idx):
pred = self(batch).argmax(dim = -1)
return pred.cpu().detach()
if __name__ == "__main__":
model = example()
batch_size = 1
trainer = pl.Trainer(logger = False)
ckpt_path = "./last.ckpt"
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"]
model.load_state_dict(ckpt, strict = False)
model.eval()
img_path = './0.jpg'
dataset = CustomDataset(img_path)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
pred = trainer.predict(model, dataloader, return_predictions = True)
print(pred)
也會修改一下這個code,因為你不會知道label是甚麼,所以少回傳一個參數,再來以這個範例,我們一次只辨識一張,所以原先是開txt檔的過程我們也省略,直接給他img_path
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class CustomDataset(Dataset):
def __init__(self, img_path):
self.data = [img_path]
self.transform = transforms.Compose([
transforms.Resize((28, 28)), # 確保圖片大小一致
transforms.ToTensor(), # 轉換為PyTorch張量
transforms.Normalize((0.5, ), (0.5, )) # 標準化
])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
path = self.data[idx]
image = Image.open(path).convert('L') # MNIST是灰度圖,轉換為'L'模式
image = self.transform(image)
return image
if __name__ == "__main__":
img_path = './0.jpg'
unit_test = CustomDataset(img_path)
for idx, (image, ) in enumerate(unit_test):
print(f'image: {image.size()}')
今天就到這裡囉~
主要就是精簡化的code,只保留inference需要的而已,雖然需要額外在寫,不過我覺得精簡過後的簡單明瞭,就給大家參考囉~