iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 4
2
AI & Machine Learning

以100張圖理解 Neural Network -- 觀念與實踐系列 第 4

Day 04:關於 Keras 的一些小技巧 -- 組態、模型存檔與實驗

前言

再往下探究之前,我們輕鬆一點,先作點實驗,驗證上上篇的程式辨識準確率是否真的那麼高? 可否在應用系統上使用? 譬如,阿拉伯數字辨識率如果那麼高,我們是否可以提供手寫板,讓用戶直接輸入,用於輸入密碼、開鎖、填寫問卷、考試答題...等等。

另外,在實驗之前,我們先討論一些 Keras 小技巧,讓我們在開發程式時更有效率,包括:

  1. 模型存檔
  2. Keras 組態
  3. 資料集(Datasets)
  4. Keras事先訓練好的應用程式(Applications)

模型存檔(Persistence)

模型訓練完畢後,結果如可接受,可以將模型存檔,下次要再測試時,就可直接載入,不需重新訓練,模型的資訊包括結構及訓練出來的權重(W)。

  1. 模型結構存檔:以下程式將結構存到 model.config 檔案,檔案為JSON或YAML格式。
from keras.models import model_from_json
json_string = model.to_json() with open("model.config", "w") as text_file:    
text_file.write(json_string)
  1. 權重(W)存檔:以下程式將權重存到 model.weight 檔案。
model.save_weights("model.weight")
  1. 同時儲存結構與權重,檔案的類別為HDF5。
from keras.models import load_model

model.save('model.h5')  # creates a HDF5 file 'model.h5'

模型載入

之後,我們要使用時,可輸入下列程式碼,載入模型結構及權重(W)。

import numpy as np  
from keras.models import Sequential
from keras.models import model_from_json
with open("model.config", "r") as text_file:
    json_string = text_file.read()
    model = Sequential()
    model = model_from_json(json_string)
    model.load_weights("model.weight", by_name=False)

或者直接載入HDF5檔案

from keras.models import load_model

# 刪除既有模型變數
del model 

# 載入模型
model = load_model('my_model.h5')

Keras 組態

  1. Keras 組態檔名稱為 keras.json,會儲存在使用者資料夾下的『.keras』子目錄。
  2. 如果你下載Keras事先訓練好的應用程式(Applications),它就會放在使用者資料夾下的『.keras\models』子目錄。
  3. 如果你下載Keras的資料集(Datasets),例如,之前程式下載 MNIST 阿拉伯數字資料集,它就會放在使用者資料夾下的『.keras\datasets』子目錄。不用 (X_train, y_train), (X_test, y_test) = mnist.load_data(),要直接開啟檔案,程式碼如下:
f = np.load(get_file("mnist.npz", origin="~/.keras"))
x_train = f['x_train']
y_train = f['y_train']
x_test = f['x_test']
y_test = f['y_test']
f.close()

如果,直接從網路下載,可改為

f = np.load(get_file("mnist.npz", origin="https://s3.amazonaws.com/img-datasets/mnist.npz"))

Keras事先訓練好的應用程式(pre-trained Applications)

Keras提供幾個事先訓練好的經典應用程式,不必重新訓練,可直接套用,請參考官方文件,使用方法如下:

from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np

model = VGG16(weights='imagenet', include_top=False)

官方文件找不到詳細用法,我花費好一番功夫才弄懂,後面談到 CNN 會詳細介紹,敬請期待。

資料集(Datasets)

Keras提供幾個現成的資料集,可作為訓練/測試資料,,請參考官方文件,包括手寫數字、分類圖片、影評、新聞、... 等。也可以自其他網站下載,例如,你覺得辨識0~9不過癮,也想辨識 A~Z, a~z,可至這裡下載。

實驗

我用C#寫了一個Draw.exe 小程式, Source Code 放在這裡,可以使用滑鼠,書寫數字,並將它存成與MNIST類似的格式(.csv),再用Python程式載入,依照訓練出來的模型測試是否可以辨識,步驟如下:

  1. 執行 Draw.exe,書寫 0~9,並存成 0.csv, 1.csv ..., 9.csv。
  2. 在DOS下,執行 python 0_1.py,假設此程式與*.csv放在同目錄。
  3. 可以看到10個數字的辨識結果,如果都正確,那就恭喜你了。

https://ithelp.ithome.com.tw/upload/images/20171215/20001976FnhgIqTYRO.jpg
圖. 手寫數字 9 的比較,左為 MNIST, 右為筆者以 Draw.exe 手寫的數字

結論

筆者反覆測試多次,發覺測試結果並不如MNIST測試資料那麼準確,正所謂『盡信書,不如無書』,原因如下:

  1. 觀察原圖,MNIST資料應是請測試者在紙上書寫,再掃描進電腦,因為,原檔如下,有毛邊。
  2. 使用Draw.exe書寫與MNIST有差異,造成辨識率不佳。

https://ithelp.ithome.com.tw/upload/images/20181012/200019769meIAgcAYM.png

另外,訓練出來的準確率均達85%,甚至95%,乍看很高,但仔細想想,如果是應用在銀行存款數目的辨識,使用者輸入10位數,只要一個數字錯,銀行老董可能就要崩潰了,反之,用在遊戲中,使用者可能會讚聲連連,驚嘆不已,所以,Machine Learning 的應用還是必須考量使用的時機與應用場域,才能贏得掌聲。


上一篇
Day 03:Neural Network 的概念探討
下一篇
Day 05:Keras 模型、函數及參數使用說明
系列文
以100張圖理解 Neural Network -- 觀念與實踐31
0
Arsene
iT邦新手 5 級 ‧ 2018-04-11 14:54:31

之前有試跑過MNIST數字辨識, 也是覺得它很厲害
感謝提供 Draw.exe
而我自己寫的數字, 用之前MNIST訓練出的模型, 辨識率不到一半..

試著找原因
可能是因轉成 28*28的圖,
轉換上有些失真,
比如我的0, 把它印出來, 就變這樣
https://ithelp.ithome.com.tw/upload/images/20180411/20109368KEucYSf62V.png
它就被辨識為 9 了

是,你說對了。
MNIST 的圖片應該是請測試者在紙上寫下來,再經掃描,所以,放大看如下,筆劃的寬度不固定且無鋸齒狀,與Draw.exe 不同,所以,辨識率較差。
https://ithelp.ithome.com.tw/upload/images/20180411/20001976oTDinJFfmu.png

0
gavinsu
iT邦見習生 0 級 ‧ 2018-10-10 22:43:35

實驗
我用C#寫了一個Draw.exe 小程式, Source Code 放在這裡,可以使用滑鼠,書寫數字,並將它存成與MNIST類似的格式(.csv),再用Python程式載入...

請問老師要如何用Python程式載入CSV?試很久都沒辦法~

請問你是哪一班的同學?
我把程式放在 Google Drive,會保留三天。
請參考 0_1.py

for i in range(0, 10):
    X2 = np.genfromtxt('./'+str(i)+'.csv', delimiter=',').astype('float32')  
    X1 = X2.reshape(1,28*28) / 255
    predictions = model.predict_classes(X1)
    # get prediction result
    print(predictions)
0
gavinsu
iT邦見習生 0 級 ‧ 2018-10-10 23:45:56

報告老師, 執行會有錯誤 (有將*.csv跟0_1.py放在同一檔案夾)

C:\Users\pp>C:\Users\pp\Desktop\0_1\0_1.py
C:\Users\pp\Anaconda3\lib\site-packages\h5py_init_.py:36: FutureWarning: Conversion of the second argument of issubdtype from float to np.floating is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type.
from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Traceback (most recent call last):
File "C:\Users\pp\Desktop\0_1\0_1.py", line 14, in
X2 = np.genfromtxt('./'+str(i)+'.csv', delimiter=',').astype('float32')
File "C:\Users\pp\Anaconda3\lib\site-packages\numpy\lib\npyio.py", line 1689, in genfromtxt
fhd = iter(np.lib._datasource.open(fname, 'rt', encoding=encoding))
File "C:\Users\pp\Anaconda3\lib\site-packages\numpy\lib_datasource.py", line 260, in open
return ds.open(path, mode, encoding=encoding, newline=newline)
File "C:\Users\pp\Anaconda3\lib\site-packages\numpy\lib_datasource.py", line 616, in open
raise IOError("%s not found." % path)
OSError: ./0.csv not found.

抱歉,少放了0.csv,已補上。你可以使用 Google Drive 的 draw/draw.exe 寫 0~9,存檔後產生的.csv 放在 01.py 的相同目錄,再執行 01.py。

0
gavinsu
iT邦見習生 0 級 ‧ 2018-10-11 20:49:32

報告老師, 已解決問題, 實驗結果辨識度很不理想, 請問要多訓練嗎?
以下為實驗結果辨識度:
[5]
[1]
[2]
[3]
[9]
[9]
[6]
[7]
[5]
[9]

請參考本文的結論。

0
gavinsu
iT邦見習生 0 級 ‧ 2018-10-13 08:29:54

謝謝老師的指點,感恩~

/images/emoticon/emoticon08.gif

0
funpi89
iT邦新手 5 級 ‧ 2018-11-05 01:18:07

想請問一下我每次修改完超參數執行0.py後都會出現"GPU Sync fail"的訊息,然後就要重開機才可以執行,執行玩0.py後要執行01.py時卻又出現同樣的錯誤以致又要重開機,該怎麼做才不會發生這種一直需要重開機的情況

我沒有碰過這種情形,可以參考以下討論:
https://stackoverflow.com/questions/51112126/gpu-sync-failed-while-using-tensorflow
https://github.com/tensorflow/tensorflow/issues/1450
https://github.com/tensorflow/tensorflow/issues/4425

綜合來看,應該是 Cuda/cuDNN 安裝或是特定GPU的問題,例如 GTX 950M。

我要留言

立即登入留言