導入套件
import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
模型初始化資料
2.1 資料內容
2.2 程式碼
app = Flask(__name__)
# Team captain's email (combined with a timestamp to derive server_uuid)
CAPTAIN_EMAIL = 'XXXXXX@gmail.com'
# Salt appended to the input string before hashing into server_uuid
SALT = '1688'
# Force CPU-only inference (hide all GPUs)
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
# Placeholders for the 4 models; populated in init() before the first request
# Xception (Keras)
model_Xception = None
# InceptionResNetV2 (Keras)
model_V2 = None
# DenseNet201 with stochastic weight averaging (PyTorch)
model_swa = None
# SVM rejection model (project-local R_model_load.Model)
model_R = None
# Append-only log of every received base64 image payload.
# NOTE(review): handle is never closed or flushed explicitly — relies on
# process exit; consider explicit flush per write.
file1 = open('./pic_base64.txt', 'a')
# Official list of the 800 target characters
words_path = r'./800_words.txt'
# NOTE(review): list(read()) splits per character — assumes the file holds
# the 800 characters with no separators/newlines; verify, and note the
# handle is never closed.
file2 = open(words_path, 'rt', encoding='Big5')
labels = list(file2.read())
# Per-class ensemble weights and rejection thresholds
# Load the table (Big5-encoded CSV)
weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
# DenseNet (SWA) per-class weights
weight_swa = weight_df['wei_ex6'].values
# InceptionResNetV2 per-class weights
weight_V2 = weight_df['wei_ex5'].values
# Xception per-class weights
weight_Xception = weight_df['wei_3'].values
API初始化
3.1 before_first_request:在處理第一個request前,先執行API初始化,用以載入模型。
3.2 程式碼
@app.before_first_request
def init():
    """Load all four models into module-level globals before the first request.

    Runs once via Flask's before_first_request hook; all models are loaded
    for CPU inference (GPUs are disabled at module level).

    NOTE(review): before_first_request was removed in Flask 2.3 — confirm
    the pinned Flask version or migrate to an explicit startup call.
    """
    global model_Xception, model_V2, model_swa, model_R
    # Keras models
    model_Xception = load_model('./Xception_retrained_v2.h5')
    model_V2 = load_model('./InceptionResNetV2.h5')
    # PyTorch DenseNet201 wrapped in an SWA AveragedModel, CPU weights
    backbone = models.densenet201(num_classes=800)
    model_swa = AveragedModel(backbone)
    model_swa.eval()
    model_swa.load_state_dict(torch.load('./swa_densenet201.pth',
                                         map_location=torch.device('cpu')))
    # SVM rejection model
    model_R = Model().load("./model_svm_v3")
    print('====================API初始化完成init====================')
產出server_uuid
def generate_server_uuid(input_string):
    """Return the SHA-256 hex digest of ``input_string`` salted with SALT."""
    digest = hashlib.sha256((input_string + SALT).encode("utf-8"))
    return digest.hexdigest()
檢查預測結果是否為字串:供後續輸出預測結果之前,判定資料型態。
def _check_datatype_to_string(prediction):
if isinstance(prediction, str):
return True
raise TypeError('Prediction is not in string type.')
將接收到的圖片轉換格式
6.1 流程
6.2 程式碼
def base64_to_binary_for_cv2(image_64_encoded):
    """Decode a base64 image and build the three model-specific inputs.

    Returns a tuple (image_for_Xception, image_for_V2, image_for_swa):
    two batched, [0, 1]-scaled numpy arrays for the Keras models and one
    transformed torch tensor for the DenseNet(SWA) model. Also appends the
    raw base64 payload to the module-level log file.
    """
    # base64 string -> bytes -> BGR numpy array
    img_base64_binary = image_64_encoded.encode("utf-8")
    img_binary = base64.b64decode(img_base64_binary)
    image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
    cv2.IMREAD_COLOR)
    # Project-local preprocessing; presumably returns a single-channel
    # grayscale image (the GRAY2RGB conversion below requires it) — confirm
    # against img_gray.process_img.
    image = process_img(image)
    image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
    image_for_tensorflow = np.asarray(image)
    # Log the received base64 payload for later inspection
    file1.write(image_64_encoded + '\n')
    # Xception input: 80x80, batch dimension, scaled to [0, 1]
    image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
    interpolation=cv2.INTER_CUBIC)
    image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
    image_for_Xception = image_for_Xception / 255
    # InceptionResNetV2 input: 150x150, batch dimension, scaled to [0, 1]
    image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
    interpolation=cv2.INTER_CUBIC)
    image_for_V2 = np.expand_dims(image_for_V2, axis=0)
    image_for_V2 = image_for_V2 / 255
    # DenseNet201 input: color-jittered, tensorized, normalized torch tensor
    transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
    contrast=(6, 6), saturation=(1, 1),
    hue=(-0.1, 0.1)), ToTensor(),
    Normalize((0.5,), (0.5,))])
    # NOTE(review): Image.ANTIALIAS was removed in Pillow 10 — verify the
    # pinned Pillow version or migrate to Image.LANCZOS.
    image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
    image_for_swa = transforms(image_for_swa)
    return image_for_Xception,image_for_V2,image_for_swa
模型辨識手寫中文字
7.1 流程
7.2 程式碼
def predict(image_for_Xception,image_for_V2,image_for_swa):
    """Ensemble the three CNNs plus an SVM reject model into one character.

    Each model's length-800 probability vector is multiplied by its
    per-class weight vector, the three weighted vectors are summed, and the
    argmax is taken. A per-character threshold and then the SVM model decide
    whether to reject the input as "isnull". Returns the predicted
    character (or "isnull") as a string.
    """
    # InceptionResNetV2: probability vector, then per-class weighting
    pred_V2 = model_V2.predict(image_for_V2)[0]
    new_V2_prob = pred_V2 * weight_V2
    # Xception: probability vector, then per-class weighting
    pred_Xception = model_Xception.predict(image_for_Xception)[0]
    new_Xception_prob = pred_Xception * weight_Xception
    # DenseNet201 (SWA): forward pass, softmax, per-class weighting
    img = image_for_swa.view(1, 3, 80, 80)
    output = model_swa(img)
    output = output.view(-1, 800)
    output_prob = nn.functional.softmax(output, dim=1)
    # Probability vector as numpy
    output_prob_np = output_prob.cpu().detach().numpy()[0]
    new_swa_prob = output_prob_np * weight_swa
    # Sum the three weighted vectors; pick the best class
    new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
    max_prob = np.max(new_prob)
    pred_word = np.argmax(new_prob)
    # Look up this character's rejection threshold.
    # NOTE(review): .values yields an array; if `judge` has no row in
    # model_weight_final.csv this is empty and the comparison below
    # misbehaves — confirm every label has a row.
    judge = labels[pred_word]
    mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    # Below the per-character threshold -> reject as "isnull"
    if max_prob < mean_prob:
        prediction = "isnull"
    else:
        # Second opinion from the SVM model
        new_prob_2dim = new_prob[np.newaxis,:]
        # SVM verdict: 1 = within the 800 characters, 2 = isnull
        pred = model_R.predict(new_prob_2dim)
        if pred == 2:
            prediction = "isnull"
        else:
            final_answer = np.argmax(new_prob)
            prediction = labels[final_answer]
    # Guard: the API contract requires a string answer
    if _check_datatype_to_string(prediction):
        return prediction
API服務(inference 資料傳輸格式:json)
8.1 流程
8.2 程式碼
@app.route('/inference', methods=['POST'])
def inference():
    """POST /inference: predict the handwritten character in a base64 image.

    Expects JSON with keys 'esun_uuid' and 'image' (base64-encoded); returns
    JSON with the original esun_uuid, a salted server_uuid, the predicted
    answer, and a server timestamp.
    """
    # Parse the client request
    data = request.get_json(force=True)
    # Decode the base64 image and build the three model-specific inputs
    image_64_encoded = data['image']
    image_for_Xception, image_for_V2, image_for_swa = \
        base64_to_binary_for_cv2(image_64_encoded)
    # Build server_uuid from the captain email + current UTC timestamp.
    # (Fix: the original called utcnow() on a datetime instance, which only
    # worked because utcnow is a classmethod; call it on the class.)
    ts = str(int(datetime.datetime.utcnow().timestamp()))
    server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    # Record API errors before re-raising (the original try/except re-raised
    # without logging anything, so errors were never actually recorded).
    try:
        answer = predict(image_for_Xception,
                         image_for_V2,
                         image_for_swa)
    except Exception:
        app.logger.exception('predict failed for esun_uuid=%s',
                             data.get('esun_uuid'))
        raise
    server_timestamp = time.time()
    # Return the prediction to the client
    return jsonify({'esun_uuid': data['esun_uuid'],
                    'server_uuid': server_uuid,
                    'answer': answer,
                    'server_timestamp': server_timestamp})
if __name__ == "__main__":
    # CLI entry point: parse the port/debug options and start the Flask
    # development server on all interfaces.
    cli = ArgumentParser(usage='Usage: python ' + __file__ +
                         ' [--port <port>] [--help]')
    cli.add_argument('-p', '--port', default=8080, help='port')
    cli.add_argument('-d', '--debug', default=True, help='debug')
    args = cli.parse_args()
    app.run(host='0.0.0.0', port=args.port, debug=args.debug)
啟用API服務
9.1 如何啟用API服務
9.2 成功啟用API服務(如下圖)
讓我們繼續看下去...