Recap: I came across LitServe just yesterday and decided to give it a try. The architecture really is clean, but it takes some getting used to.
Here I followed the official example and swapped in a BERT model. This BERT model's downstream task is text sentiment analysis: given a piece of text, it classifies it as positive, negative, or neutral.
Today I used Start Debugging to step through each stage, and the flow was exactly what I had expected yesterday.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from litserve import LitAPI, LitServer

class BERTLitAPI(LitAPI):
    def setup(self, device):
        """
        Load the tokenizer and model, and move the model to the specified device.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("MarieAngeA13/Sentiment-Analysis-BERT")
        self.model = AutoModelForSequenceClassification.from_pretrained("MarieAngeA13/Sentiment-Analysis-BERT")
        # Move the model to the target device (e.g., CPU, GPU)
        self.model.to(device)
        # Set the model to evaluation mode
        self.model.eval()

    def decode_request(self, request):
        """
        Preprocess the request data (tokenize).
        """
        # The request is expected to be a dictionary with a "text" field
        inputs = self.tokenizer(request["text"], return_tensors="pt", padding=True, truncation=True, max_length=512)
        return inputs

    def predict(self, inputs):
        """
        Perform the inference.
        """
        with torch.no_grad():
            # Make sure the input tensors live on the same device as the model
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
        return outputs.logits
    def encode_response(self, logits):
        """
        Process the model output into a response dictionary.
        """
        # Convert logits to probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # This model has three classes: negative, neutral, positive
        response = {
            "negative": probabilities[:, 0].item(),
            "neutral":  probabilities[:, 1].item(),
            "positive": probabilities[:, 2].item(),
        }
        return response
if __name__ == "__main__":
    api = BERTLitAPI()
    server = LitServer(api, accelerator="cpu")
    server.run(port=8000)
On the client side, I send three test sentences to the /predict endpoint with requests:

import requests

text_arr = ["This is a bad example", "This is a example", "This is a great example"]
for text in text_arr:
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json={"text": text},
    )
    print(f"text: {text}, Response: {response.text}\n")
test_upload_server.py
Here we again test how to upload a file; the official tutorial doesn't cover this part. As before, we use the same 0.jpg for testing, and the file is saved the same way as before.
[Updated 8/29]
Checking the official docs today, I found they now have a clear example with a matching server and client, so I'm adding the link; anyone interested can give it a try~~
https://lightning.ai/docs/litserve/features/request-response-format
from fastapi import Request, Response
from litserve import LitAPI, LitServer

class SimpleFileLitAPI(LitAPI):
    def setup(self, device):
        # Dummy model: square the input
        self.model = lambda x: x**2

    def decode_request(self, request: Request):
        # Save the uploaded file to disk, then hand a dummy value to predict
        with open('0_upload.jpg', "wb+") as file_object:
            file_object.write(request["file"].file.read())
        return 1

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output) -> Response:
        return {"output": output}

if __name__ == "__main__":
    server = LitServer(SimpleFileLitAPI(), accelerator="cpu")
    server.run(port=8000)
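As a side note, instead of writing the upload straight to disk inside decode_request, you could also decode the image in memory and have predict work on it directly. A rough sketch of that variant; it assumes Pillow is installed, and InMemoryFileLitAPI is just a name I made up:

# Variant sketch (assumes Pillow): decode the upload in memory instead of
# saving it, and "predict" the image dimensions.
import io
from PIL import Image
from litserve import LitAPI, LitServer

class InMemoryFileLitAPI(LitAPI):
    def setup(self, device):
        # Placeholder "model": report the image dimensions
        self.model = lambda img: img.size

    def decode_request(self, request):
        # Read the uploaded bytes and decode them as an image in memory
        data = request["file"].file.read()
        return Image.open(io.BytesIO(data))

    def predict(self, image):
        return self.model(image)

    def encode_response(self, output):
        width, height = output
        return {"width": width, "height": height}

if __name__ == "__main__":
    server = LitServer(InMemoryFileLitAPI(), accelerator="cpu")
    server.run(port=8000)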
test_upload_client.py
Here we upload the file through requests, again in multipart/form-data format. After running client.py, 0_upload.jpg does indeed show up in the directory.
import requests

url = "http://127.0.0.1:8000/predict"
# The third tuple element is the file part's content type (image/jpeg here);
# requests builds the multipart/form-data envelope automatically.
files = {'file': ('0.jpg', open('0.jpg', 'rb'), 'image/jpeg')}
response = requests.post(url, files=files)
print(response.text)
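If you want to be extra sure the upload arrived intact, one quick way is to compare the hashes of the original and the saved file. This check is my own addition, not part of the example; hashlib is from the standard library:

# Quick integrity check: 0_upload.jpg should hash the same as 0.jpg
import hashlib

def sha256_of(path):
    with open(path, 'rb') as f:
        return hashlib.sha256(f.read()).hexdigest()

print(sha256_of('0.jpg') == sha256_of('0_upload.jpg'))  # expect: True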
That's it for today~~
This GitHub project has only just gotten started, so some of the tutorials haven't been written yet and I had to experiment on my own. That said, I quite like how it splits the program into individual blocks, so I'll keep studying it, especially since they're still actively updating it.