iT邦幫忙

2021 iThome 鐵人賽

DAY 29
0
自我挑戰組

資料分析及AI深度學習-簡單基礎實作系列 第 29

DAY29:開啟API服務(完賽)

  • 分享至 

  • xImage
  •  

部署及開啟API服務-flask

  1. 導入套件
import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
  1. 初始化

    • 隊長Email
    • uuid加密
    • CPU運算:因GCP免費試用,開啟的VM無GPU,故以CPU運算。
    • 4個Model初始化:3個影像辨識模型+1個SVM模型(用以判斷isnull)。
    • 接收圖片的Log與官方800字清單
    • 模型組合之權重與閾值表

    程式碼

    app = Flask(__name__)
    
    # 隊長email
    CAPTAIN_EMAIL = 'XXXXXX@gmail.com'
    
    # uuid加密
    SALT = '1688'
    
    # CPU運算(關閉GPU)
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    
    # 4個Model初始化
    # Xception
    model_Xception = None
    # InceptionResNetV2
    model_V2 = None
    # Densenet201
    model_swa = None
    # R_SVM_model
    model_R = None
    
    # 接收的圖片Log檔
    file1 = open('./pic_base64.txt', 'a')
    # 官方800字清單
    words_path = r'./800_words.txt'
    file2 = open(words_path, 'rt', encoding='Big5')
    labels = list(file2.read())
    
    # 模型組合之權重與閾值表
    # 載入表格
    weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
    # DenseNet權重
    weight_swa = weight_df['wei_ex6'].values
    # InceptionResNetV2權重
    weight_V2 = weight_df['wei_ex5'].values
    # Xception權重
    weight_Xception = weight_df['wei_3'].values   
    
  2. API初始化

    • before_first_request:在處理第一個request前,先執行API初始化,用以載入模型。使用此裝飾器的原因:

      • 一開始沒有用多線程,我們的模型又較大,在接收圖片時會處理較慢,導致無法一次接收處理多張圖片,讓我們比賽有一天只回傳不到一半的答案。
      • 後來使用多線程發現,tensorflow的模型會無法讀取到,後來找到before_first_request這個解決方式。
    • 程式碼

      @app.before_first_request
      def init():
         # Xception
         global model_Xception
         model_Xception = load_model('./Xception_retrained_v2.h5')
      
         # InceptionResNetV2
         global model_V2
         model_V2 = load_model('./InceptionResNetV2.h5')
      
         # DenseNet201
         global model_swa
         model_densenet = models.densenet201(num_classes=800)
         model_path = './swa_densenet201.pth'
         model_fang = model_densenet
         model_swa = AveragedModel(model_fang)
         model_swa.eval()
         model_swa.load_state_dict(torch.load(model_path,
                                   map_location=torch.device('cpu')))
      
         # SVM模型
         global model_R
         MODEL_PATH = "./model_svm_v3"
         model_R = Model().load(MODEL_PATH)
         print('====================API初始化完成init====================')
      
  3. 產出server_uuid

def generate_server_uuid(input_string):
    s = hashlib.sha256()
    data = (input_string + SALT).encode("utf-8")
    s.update(data)
    server_uuid = s.hexdigest()
    return server_uuid
  1. 檢查預測結果是否為字串:供後續輸出預測結果之前,判定資料型態。
def _check_datatype_to_string(prediction):
    if isinstance(prediction, str):
        return True
    raise TypeError('Prediction is not in string type.')
  1. 將接收到的圖片轉換格式

    • 將base64編碼轉換成numpy格式。

    • 將圖片去雜訊並轉換成灰階。

    • 紀錄比賽圖片樣本,存入log檔,供後續改善模型之用。

    • 將圖片轉換成模型input格式

    • 程式碼

    def base64_to_binary_for_cv2(image_64_encoded):
        # base64轉numpy
        img_base64_binary = image_64_encoded.encode("utf-8")
        img_binary = base64.b64decode(img_base64_binary)
        image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                             cv2.IMREAD_COLOR)
    
        # 圖片預處理
        image = process_img(image)
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
        image_for_tensorflow = np.asarray(image)
    
        # 將接收的圖片,儲存到Log檔
        file1.write(image_64_encoded + '\n')
    
        # Xception之input圖片格式
        image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
                                        interpolation=cv2.INTER_CUBIC)
        image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
        image_for_Xception = image_for_Xception / 255
    
        # InceptionResNetV2之input圖片格式
        image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
                                  interpolation=cv2.INTER_CUBIC)
        image_for_V2 = np.expand_dims(image_for_V2, axis=0)
        image_for_V2 = image_for_V2 / 255
    
        # DenseNet201之input圖片格式
        transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                              contrast=(6, 6), saturation=(1, 1),
                              hue=(-0.1, 0.1)), ToTensor(),
                              Normalize((0.5,), (0.5,))])
        image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
        image_for_swa = transforms(image_for_swa)
    
        return image_for_Xception,image_for_V2,image_for_swa     
    
  2. 辨識手寫中文字

    • 計算3個模型之800字機率,並乘以加權分數。

    • 將800字的加權機率進行加總,取得新的800字機率。

    • 從新的800字機率中,取機率值最大的那個字,做為預測結果。

    • 以閾值判斷,該字是否屬於800字內。若機率大於閾值,輸出該字;反之,則輸出isnull。

    • 檢查預測結果是否為字串。

    • 程式碼

    def predict(image_for_Xception,image_for_V2,image_for_swa):
        # InceptionResNetV2 predict的機率加權
        # 機率向量
        pred_V2 = model_V2.predict(image_for_V2)[0]
        # 乘上權重的新機率向量
        new_V2_prob = pred_V2 * weight_V2
    
        # Xception predict的機率加權
        # 機率向量
        pred_Xception = model_Xception.predict(image_for_Xception)[0] 
        # 乘上權重的新機率向量
        new_Xception_prob = pred_Xception * weight_Xception 
    
        # DenseNet201 predict的機率加權
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img)
        output = output.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        # 機率向量
        output_prob_np = output_prob.cpu().detach().numpy()[0]
        # 乘上權重取得新機率向量
        new_swa_prob = output_prob_np * weight_swa 
    
        # 三個模型向量相加取得新的向量,判定手寫中文字
        new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
        max_prob = np.max(new_prob)
        pred_word = np.argmax(new_prob)
    
        # 讀取該手寫中文字的閾值
        judge = labels[pred_word]
        mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    
        # 判斷閾值
        if max_prob < mean_prob:
            prediction = "isnull"
        else:
            # 考慮加上SVM模型
            new_prob_2dim = new_prob[np.newaxis,:]
            # 丟入Rmodel預測是否為isnull,1為800字內,2為isnull
            pred = model_R.predict(new_prob_2dim)
            if pred == 2:
                prediction = "isnull"
            else:
                final_answer = np.argmax(new_prob)
                prediction = labels[final_answer]
    
        # 檢查預測結果是否為字串
        if _check_datatype_to_string(prediction):
            return prediction
    
  3. API服務(inference 資料傳輸格式:json)

    • 接收API用戶之request。

    • 取出json中image,並轉換成圖片格式。

    • 產出server_uuid:做為回傳時json內容之一。

    • 記錄錯誤log:供後續檢查API服務error之用。

    • 回傳預測結果給主辦方。

    • 程式碼

    @app.route('/inference', methods=['POST'])
    def inference():
        # 接收用戶request
        data = request.get_json(force=True)
    
        # 取image base64 encoded,並以cv2轉換格式
        image_64_encoded = data['image']
        image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
    
        # 產出server_uuid
        t = datetime.datetime.now()
        ts = str(int(t.utcnow().timestamp()))
        server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    
        # 記錄API錯誤log
        try:
            answer = predict(image_for_Xception,
                             image_for_V2,
                             image_for_swa)
        except TypeError as type_error:
            raise type_error
        except Exception as e:
            raise e
        server_timestamp = time.time()
    
        # 回傳預測結果給用戶
        return jsonify({'esun_uuid': data['esun_uuid'],
                        'server_uuid': server_uuid,
                        'answer': answer,
                        'server_timestamp': server_timestamp})
    
    if __name__ == "__main__":
    
        arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                   ' [--port <port>] [--help]')
        arg_parser.add_argument('-p', '--port', default=8080, help='port')
        arg_parser.add_argument('-d', '--debug', default=True, help='debug')
        options = arg_parser.parse_args()
    
        app.run(host='0.0.0.0', port=options.port, debug=options.debug)
    

小結

  • 整個比賽下來收穫很多,雖然成績不是說特別好,但對於第一次參賽的我們,覺得這個經驗非常值得。

  • 結束了整個流程,明天來檢討哪些地方可以做改善,以及參考得獎隊伍的做法。


上一篇
DAY28:VM安裝套件以及GCP注意事項
下一篇
DAY30:賽後心得檢討
系列文
資料分析及AI深度學習-簡單基礎實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言