iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 13
1
AI & Data

輕鬆掌握 Keras 及相關應用系列 第 13

Day 13:測試 CNN 的桌面程式

  • 分享至 

  • xImage
  •  

前言

前一篇利用【資料增補】(Data Augmentation)技術,擴增訓練資料,對於準確率是否有真的提升嗎? 這次,我們就來實作一個具有書寫介面的桌面程式,看看 CNN + Data Augmentation 到底行不行啊。

程式開發分兩個步驟:

  1. 視窗介面:利用 Tkinter 套件的 Canvas 元件,提供以滑鼠書寫阿拉伯數字的功能,並含存檔功能。
  2. 模型預測:將書寫的影像交由 CNN 模型預測,並含模型訓練。

https://ithelp.ithome.com.tw/upload/images/20200912/200019767i2jyecyCF.png
圖一. 視窗介面

視窗介面

首先,撰寫一個視窗介面,無辨識功能,程式碼如下,檔案名稱為 13_01_Canvas.py。

from tkinter import *
from tkinter import filedialog

from PIL import ImageDraw, Image, ImageGrab
import numpy as np
from skimage import color
from skimage import io
import os
import io

class Paint(object):
    # 類別初始化函數
    def __init__(self):
        self.root = Tk()

        #defining Canvas
        self.c = Canvas(self.root, bg='white', width=280, height=280)
        
        self.image1 = Image.new('RGB', (280, 280), color = 'white')
        self.draw = ImageDraw.Draw(self.image1) 

        self.c.grid(row=1, columnspan=6)

        # 建立【辨識】按鈕
        self.classify_button = Button(self.root, text='辨識', command=lambda:self.classify(self.c))
        self.classify_button.grid(row=0, column=0, columnspan=2, sticky='EWNS')

        # 建立【清畫面】按鈕
        self.clear = Button(self.root, text='清畫面', command=self.clear)
        self.clear.grid(row=0, column=2, columnspan=2, sticky='EWNS')

        # 建立【存檔】按鈕
        self.savefile = Button(self.root, text='存檔', command=self.savefile)
        self.savefile.grid(row=0, column=4, columnspan=2, sticky='EWNS')

        # 建立【預測】文字框
        self.prediction_text = Text(self.root, height=2, width=10)
        self.prediction_text.grid(row=2, column=4, columnspan=2)

        # self.model = self.loadModel()
        
        # 定義滑鼠事件處理函數
        self.setup()
        
        # 監聽事件
        self.root.mainloop()

    # 滑鼠事件處理函數
    def setup(self):
        self.old_x = None
        self.old_y = None
        self.line_width = 15
        self.color = 'black'
        
        # 定義滑鼠事件處理函數,包括移動滑鼠及鬆開滑鼠按鈕
        self.c.bind('<B1-Motion>', self.paint)
        self.c.bind('<ButtonRelease-1>', self.reset)

    # 移動滑鼠 處理函數
    def paint(self, event):
        paint_color = self.color
        if self.old_x and self.old_y:
            self.c.create_line(self.old_x, self.old_y, event.x, event.y,
                               width=self.line_width, fill=paint_color,
                               capstyle=ROUND, smooth=TRUE, splinesteps=36)
            # 顯示設定>100%,抓到的區域會變小
            # 畫圖同時寫到記憶體,避免螢幕字型放大,造成抓到的畫布區域不足
            self.draw.line((self.old_x, self.old_y, event.x, event.y), fill='black', width=5)

        self.old_x = event.x
        self.old_y = event.y

    # 鬆開滑鼠按鈕 處理函數
    def reset(self, event):
        self.old_x, self.old_y = None, None

    # 【清畫面】處理函數
    def clear(self):
        self.c.delete("all")
        self.image1 = Image.new('RGB', (280, 280), color = 'white')
        self.draw = ImageDraw.Draw(self.image1) 
        self.prediction_text.delete("1.0", END)

    # 【存檔】處理函數
    def savefile(self):
        f = filedialog.asksaveasfilename( defaultextension=".png", filetypes = [("png file",".png")])
        if f is None: # asksaveasfile return `None` if dialog closed with "cancel".
            return
        #print(f)
        self.image1.save(f)

    # 【辨識】處理函數
    def classify(self, widget):
        pass
        
if __name__ == '__main__':
    Paint()

每段程式均有註解,筆者就不解釋了,其中【辨識】處理函數(classify) 涉及模型預測留待下一段處理。特別注意,筆者發現如果電腦的文字顯示設定>100%,如下圖,直接擷取視窗的Canvas,抓到的區域會不完整,所以,筆者直接抓記憶體,但是會有速度跟不上,字體會有殘缺的問題,如下圖。
https://ithelp.ithome.com.tw/upload/images/20200913/200019764WwsRowOSh.png
圖二. 文字顯示設定

https://ithelp.ithome.com.tw/upload/images/20200913/20001976uWe9g0V9zU.png
圖三. 記憶體繪製,速度跟不上,字體會有殘缺

程式執行方式為:
python 13_01_Canvas.py

模型預測

先寫一個模組,供主程式呼叫,功能包括取得MNIST訓練資料(getData)、訓練模型(trainModel)、載入模型
(loadModel),檔案名稱為 cnn_class.py,程式碼如下:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.utils import to_categorical
import os

# 取得 MNIST 資料
def getData():
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    img_rows, img_cols = 28, 28

    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)

    # CNN 需加一維
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)

    return X_train/255, y_train, X_test/255, y_test

# 訓練模型
def trainModel(X_train, y_train, X_test, y_test):
    batch_size = 64
    epochs = 15

    model = tf.keras.models.Sequential()

    model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu', input_shape=(28,28,1)))
    model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu'))
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(rate=0.25))

    model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
    model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(rate=0.25))

    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(rate=0.5))
    model.add(Dense(10, activation='softmax'))

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
            rotation_range=10,
            zoom_range=0.1,
            width_shift_range=0.1,
            height_shift_range=0.1)

    model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
    datagen.fit(X_train)
    history = model.fit(datagen.flow(X_train, y_train, batch_size=batch_size), epochs=epochs,
                                  validation_data=datagen.flow(X_test, y_test, batch_size=batch_size), verbose=2,
				                  steps_per_epoch=X_train.shape[0]//batch_size)

    model.save('mnist_model.h5')
    return model
    
# 載入模型
def loadModel():
    return tf.keras.models.load_model('mnist_model.h5')

接著,修改13_01_Canvas.py,在【辨識】按鈕的處理函數加上辨識功能:

    # 【辨識】處理函數
    def classify(self, widget):
        # self.image1.save('原圖.png')
        img = self.image1.resize((28, 28), ImageGrab.Image.ANTIALIAS).convert('L')
        # img.save('縮小.png')
        
        img = np.array(img)
        # Change pixels to work with our classifier
        img = (255 - img) / 255
        
        img2=Image.fromarray(img) 
        #img2.save('2.png')

        img = np.reshape(img, (1, 28, 28, 1))
        
        # Predict digit
        pred = model.predict([img])
        # Get index with highest probability
        pred = np.argmax(pred)
        #print(pred)
        self.prediction_text.delete("1.0", END)
        self.prediction_text.insert(END, pred)

在主程式加上訓練模型或載入既有的模型的處理:

if __name__ == '__main__':
    # 訓練模型或載入既有的模型
    if(os.path.exists('mnist_model.h5')):
        print('load model ...')
        model = loadModel()
    else:
        print('train model ...')
        X_train, y_train, X_test, y_test = getData()
        model = trainModel(X_train, y_train, X_test, y_test)

    print(model.summary())
    
    # 顯示視窗
    Paint()

檔案名稱為 13_02_CNN_model.py。

測試

執行下列指令測試:
python 13_02_CNN_model.py

隨便書寫幾個數字,辨識都正確,效果還不錯。

測試畫面如下:
https://ithelp.ithome.com.tw/upload/images/20200912/200019767i2jyecyCF.png
圖四. 視窗介面

結論

機器學習的套件提供內建的資料集,讓學習者容易上手,雖然立意良善,但也常造成錯誤認知,誤以為模型真如訓練結果的準確度那麼高,依筆者的經驗,還是要自備一些資料預測,才可以真正達到 Out of Samples 測試,在專案中,試著開發一個介面,不管是桌面程式或網頁程式,都是值得投資的項目。

本篇範例包括 13_01_Canvas.py、13_02_CNN_model.py、cnn_class.py,可自【這裡】下載。


上一篇
Day 12:影像資料增補(Data Augmentation)
下一篇
Day 14:預先訓練好的模型(Keras Applications)
系列文
輕鬆掌握 Keras 及相關應用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言