前情提要: 昨天稍微介紹了一下fastapi如何上傳檔案,今天我們要上傳檔案後辨識,並回完結果。
其實就是將我們當初寫的infer.py裡的__main__分成兩部分搬過來,在最前面先load model進來,然後在api裡面創建Dataloader,最後送進去model做inference,然後回傳結果。
import logging
from fastapi import FastAPI, UploadFile, File
from fastapi import FastAPI
from infer import example
import torch
import lightning as pl
from dataloader_infer import CustomDataset
from torch.utils.data import DataLoader
FORMAT = '%(asctime)s %(levelname)s [%(filename)s] %(message)s'
logging.basicConfig(level = logging.INFO, format = FORMAT)
app = FastAPI(title = 'deploy mnist classifer model')
# [Day 15]
logging.info('Loading model...')
try:
model = example()
batch_size = 1
trainer = pl.Trainer(logger = False)
ckpt_path = "./last.ckpt"
ckpt = torch.load(ckpt_path, map_location = torch.device('cpu'))["state_dict"]
model.load_state_dict(ckpt, strict = False)
model.eval()
logging.info('Success loading model.')
except:
logging.error('Loading model error, please try again')
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.post("/api/v1/mnist")
async def mnist_classifer(
file: UploadFile = File(None),
):
'''
Args:
file: 上傳的檔案
'''
try:
file_location = "./temp.jpg"
with open(file_location, "wb+") as file_object:
file_object.write(file.file.read())
logging.info(f"Successfully uploaded {file.filename}")
except Exception as e:
logging.DEBUG(e)
return {"error": f"failed uploaded {file_location}"}
# [Day 15]
try:
img_path = file_location
dataset = CustomDataset(img_path)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
pred = trainer.predict(model, dataloader, return_predictions = True)
print(pred)
return {"message": str(pred[0].item())}
except Exception as e:
logging.DEBUG(f"Failed model inference, error: {e}")
return {"error": f"Failed model inference, error {e}"}
你就可以隨便拖拉jpg到選擇檔案那邊然後按Execute,就會回傳model辨識完的結果囉~~
今天的非常簡單,只是將程式做合併而已,之後你只需把model的部分換成你的,整體的格式大致上就是這樣,雖然簡單但卻十分重要。
基本上上面的範例已經非常足夠了,如果需要更進一步,就變成還要檢查是否有上傳檔案,或上傳檔案是否有符合格式,甚至一張張辨識太慢了,我想要一口氣辨識十張之類的,可以想想看怎麼做~~