import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
初始化
程式碼
# Flask application object that exposes the inference API.
app = Flask(__name__)

# Team captain's e-mail; inference() hashes it together with a timestamp.
CAPTAIN_EMAIL = 'XXXXXX@gmail.com'

# Salt appended to the text that is hashed into the server_uuid.
SALT = '1688'

# Run on CPU only (hide every GPU from the frameworks).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Placeholders for the four models; init() fills them in.
# Xception
model_Xception = None
# InceptionResNetV2
model_V2 = None
# DenseNet201
model_swa = None
# R_SVM_model
model_R = None

# Log file that collects every received base64 image. It is opened in append
# mode and deliberately kept open, because base64_to_binary_for_cv2() writes
# one line to it per request.
file1 = open('./pic_base64.txt', 'a')

# Official 800-character list. The file is read once and closed right away;
# every character of its content becomes one label, in file order.
words_path = r'./800_words.txt'
with open(words_path, 'rt', encoding='Big5') as file2:
    labels = list(file2.read())

# Table holding the ensemble weights and the per-character thresholds.
weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
# DenseNet weights
weight_swa = weight_df['wei_ex6'].values
# InceptionResNetV2 weights
weight_V2 = weight_df['wei_ex5'].values
# Xception weights
weight_Xception = weight_df['wei_3'].values
API初始化
before_first_request:在處理第一個 request 之前,先執行 API 初始化以載入模型。使用此裝飾器的原因:模型只需載入一次,不必在每次 request 時重新載入。
程式碼
@app.before_first_request
def init():
    """Load the four models into their module-level globals.

    Registered through ``app.before_first_request`` so the loading happens
    once, before the first request is handled.

    NOTE(review): newer Flask releases no longer provide
    ``before_first_request`` -- confirm the Flask version this service pins.
    """
    global model_Xception, model_V2, model_swa, model_R

    # Keras models: Xception, then InceptionResNetV2.
    model_Xception = load_model('./Xception_retrained_v2.h5')
    model_V2 = load_model('./InceptionResNetV2.h5')

    # DenseNet201 (800 classes) wrapped in an SWA AveragedModel. The wrapper
    # is switched to eval mode and its weights are restored onto the CPU.
    densenet = models.densenet201(num_classes=800)
    model_swa = AveragedModel(densenet)
    model_swa.eval()
    swa_state = torch.load('./swa_densenet201.pth',
                           map_location=torch.device('cpu'))
    model_swa.load_state_dict(swa_state)

    # SVM model.
    model_R = Model().load("./model_svm_v3")

    print('====================API初始化完成init====================')
產出server_uuid
def generate_server_uuid(input_string, salt=None):
    """Return the SHA-256 hex digest of ``input_string`` followed by a salt.

    Args:
        input_string: Text to hash.
        salt: Text appended to ``input_string`` before hashing. When left as
            ``None`` the module-level ``SALT`` is used, which is the
            behaviour existing callers get.

    Returns:
        The digest as a 64-character lowercase hexadecimal string.
    """
    if salt is None:
        # Resolved at call time so the default always follows the module SALT.
        salt = SALT
    payload = (input_string + salt).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()
def _check_datatype_to_string(prediction):
    """Confirm that ``prediction`` is a ``str``.

    Args:
        prediction: Value to validate.

    Returns:
        ``True`` when ``prediction`` is a string.

    Raises:
        TypeError: If ``prediction`` is not a string.
    """
    # Guard clause: reject anything that is not a str, otherwise fall through.
    if not isinstance(prediction, str):
        raise TypeError('Prediction is not in string type.')
    return True
將接收到的圖片轉換格式
將base64編碼轉換成numpy格式。
將圖片去雜訊並轉換成灰階。
紀錄比賽圖片樣本,存入log檔,供後續改善模型之用。
將圖片轉換成模型input格式
程式碼
def base64_to_binary_for_cv2(image_64_encoded):
    """Decode a base64 image and build the input for each of the three models.

    The payload is decoded with OpenCV, passed through the project helper
    ``process_img`` and expanded from single-channel to RGB. The original
    base64 text is appended to the image log file, and the RGB image is then
    resized and scaled once per model.

    Args:
        image_64_encoded: Base64 text of an encoded image file.

    Returns:
        A tuple ``(image_for_Xception, image_for_V2, image_for_swa)``:
        an 80x80 array and a 150x150 array, each with a leading batch axis
        and divided by 255, and the tensor produced by the torchvision
        transform pipeline from an 80x80 PIL image.

    Raises:
        ValueError: If the decoded bytes are not a readable image.
    """
    # base64 -> raw bytes -> BGR numpy array
    img_binary = base64.b64decode(image_64_encoded.encode("utf-8"))
    decoded = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                           cv2.IMREAD_COLOR)
    if decoded is None:
        # cv2.imdecode signals an unreadable payload by returning None; fail
        # here with a clear message rather than somewhere inside process_img.
        raise ValueError('Request image could not be decoded.')

    # Image pre-processing. process_img is a project helper; its output is
    # treated as single-channel and expanded to 3-channel RGB.
    processed = process_img(decoded)
    pil_image = Image.fromarray(cv2.cvtColor(processed, cv2.COLOR_GRAY2RGB))
    image_for_tensorflow = np.asarray(pil_image)

    # Append the received image (as base64 text) to the log file.
    file1.write(image_64_encoded + '\n')

    # Xception input: 80x80, batch axis added, scaled by 1/255.
    image_for_Xception = cv2.resize(image_for_tensorflow, (80, 80),
                                    interpolation=cv2.INTER_CUBIC)
    image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
    image_for_Xception = image_for_Xception / 255

    # InceptionResNetV2 input: 150x150, batch axis added, scaled by 1/255.
    image_for_V2 = cv2.resize(image_for_tensorflow, (150, 150),
                              interpolation=cv2.INTER_CUBIC)
    image_for_V2 = np.expand_dims(image_for_V2, axis=0)
    image_for_V2 = image_for_V2 / 255

    # DenseNet201 input: colour jitter, tensor conversion and normalisation.
    swa_transform = Compose([ColorJitter(brightness=(1.5, 1.5),
                                         contrast=(6, 6), saturation=(1, 1),
                                         hue=(-0.1, 0.1)), ToTensor(),
                             Normalize((0.5,), (0.5,))])
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
    # filter under its current name.
    image_for_swa = pil_image.resize((80, 80), Image.LANCZOS)
    image_for_swa = swa_transform(image_for_swa)

    return image_for_Xception, image_for_V2, image_for_swa
辨識手寫中文字
計算3個模型之800字機率,並乘以加權分數。
將800字的加權機率進行加總,取得新的800字機率。
從新的800字機率中,取機率值最大的那個字,做為預測結果。
以閾值判斷,該字是否屬於800字內。若機率大於閾值,輸出該字;反之,則輸出isnull。
檢查預測結果是否為字串。
程式碼
def predict(image_for_Xception,image_for_V2,image_for_swa):
    """Recognise one handwritten character with the weighted model ensemble.

    Each of the three networks produces an 800-entry probability vector that
    is multiplied element-wise by its weight vector. The three weighted
    vectors are summed and the arg-max entry is the candidate character. The
    candidate is rejected as ``"isnull"`` when its combined score is below
    that character's ``mean_prob_new`` threshold, or when the SVM model
    returns ``2`` for the combined vector.

    Args:
        image_for_Xception: Input batch for the Xception model.
        image_for_V2: Input batch for the InceptionResNetV2 model.
        image_for_swa: Tensor for the DenseNet201 model; it is viewed as
            ``(1, 3, 80, 80)``.

    Returns:
        The predicted character, or the string ``"isnull"``.

    Raises:
        TypeError: If the resulting prediction is not a string.
    """
    # InceptionResNetV2: probability vector times its weights.
    new_V2_prob = model_V2.predict(image_for_V2)[0] * weight_V2

    # Xception: probability vector times its weights.
    new_Xception_prob = (model_Xception.predict(image_for_Xception)[0]
                         * weight_Xception)

    # DenseNet201: this is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model_swa(image_for_swa.view(1, 3, 80, 80)).view(-1, 800)
        output_prob_np = nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
    new_swa_prob = output_prob_np * weight_swa

    # Sum the three weighted vectors and pick the best-scoring character.
    new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
    pred_word = np.argmax(new_prob)
    max_prob = new_prob[pred_word]
    judge = labels[pred_word]

    # Threshold for that character. The lookup yields an array, so take one
    # scalar explicitly: comparing against the array itself is ambiguous (or
    # an error) when the table holds zero or several rows for the character.
    # With no row the threshold test is skipped; with several the first wins.
    thresholds = weight_df.loc[weight_df["word"] == judge,
                               "mean_prob_new"].values
    below_threshold = thresholds.size > 0 and max_prob < thresholds[0]

    if below_threshold:
        prediction = "isnull"
    else:
        # Second opinion from the SVM model: 1 means the character is within
        # the 800-character list, 2 means isnull.
        pred = model_R.predict(new_prob[np.newaxis, :])
        if pred == 2:
            prediction = "isnull"
        else:
            prediction = judge

    # Raises TypeError unless the prediction is a string.
    _check_datatype_to_string(prediction)
    return prediction
API服務(inference 資料傳輸格式:json)
接收API用戶之request。
取出json中image,並轉換成圖片格式。
產出server_uuid:做為回傳時json內容之一。
記錄錯誤log:供後續檢查API服務error之用。
回傳預測結果給主辦方。
程式碼
@app.route('/inference', methods=['POST'])
def inference():
    """Handle ``POST /inference``: recognise the character in the posted image.

    The JSON body must carry ``image`` (base64 text of the image) and
    ``esun_uuid``, which is echoed back in the response.

    Returns:
        A JSON response with the keys ``esun_uuid``, ``server_uuid``,
        ``answer`` and ``server_timestamp``.

    Raises:
        Exception: Any error raised during prediction is logged and then
            re-raised unchanged.
    """
    # Parse the request body as JSON regardless of its Content-Type.
    data = request.get_json(force=True)

    # Take the base64-encoded image and build the three model inputs.
    image_64_encoded = data['image']
    image_for_Xception, image_for_V2, image_for_swa = \
        base64_to_binary_for_cv2(image_64_encoded)

    # Build the server_uuid from the captain e-mail plus the current epoch
    # second. time.time() is used directly: calling .timestamp() on the naive
    # datetime returned by utcnow() treats it as local time, which shifts the
    # value by the server's UTC offset.
    ts = str(int(time.time()))
    server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)

    # Record prediction failures in the application log, then let the error
    # propagate exactly as before.
    try:
        answer = predict(image_for_Xception,
                         image_for_V2,
                         image_for_swa)
    except Exception:
        app.logger.exception('Prediction failed for esun_uuid=%s',
                             data.get('esun_uuid'))
        raise

    server_timestamp = time.time()

    # Send the prediction back to the caller.
    return jsonify({'esun_uuid': data['esun_uuid'],
                    'server_uuid': server_uuid,
                    'answer': answer,
                    'server_timestamp': server_timestamp})
def _parse_bool_flag(value):
    """Interpret a command-line value as a boolean.

    ``'0'``, ``'false'``, ``'no'``, ``'off'`` and the empty string (case and
    surrounding whitespace ignored) mean ``False``; anything else is ``True``.
    """
    return str(value).strip().lower() not in ('0', 'false', 'no', 'off', '')


if __name__ == "__main__":
    # Command-line entry point: parse the options and start the Flask server.
    arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                ' [--port <port>] [--help]')
    # type=int makes a port given on the command line match the int default.
    arg_parser.add_argument('-p', '--port', default=8080, type=int,
                            help='port')
    # Without a converter every supplied string, including "False", was
    # truthy, so debug mode could never be switched off from the CLI.
    # NOTE(review): debug still defaults to True while the server listens on
    # 0.0.0.0 -- the interactive debugger should not be reachable from
    # untrusted networks; pass --debug false for any real deployment.
    arg_parser.add_argument('-d', '--debug', default=True,
                            type=_parse_bool_flag, help='debug')
    options = arg_parser.parse_args()
    app.run(host='0.0.0.0', port=options.port, debug=options.debug)
整個比賽下來收穫很多。雖然成績不算特別好,但對第一次參賽的我們來說,這次經驗非常值得。
整個流程到此結束,明天再來檢討哪些地方可以改善,並參考得獎隊伍的做法。