iT邦幫忙

2022 iThome 鐵人賽

DAY 7
0
AI & Data

機器學習的 hello world - 用手寫數字辨識系統學習 ML 的 30 天系列 第 7

[DAY7] 怎麼找適合的神經網路方法?以機器學習找方法三步驟解釋

  • 分享至 

  • xImage
  •  

接著講怎麼找一個適合手寫數字辨識系統神經網路方法


機器學習找方法三步驟

機器學習而找適合函數或說問題解法可以簡單濃縮成三步驟:

  1. define a set of function:先找一個機器學習的模型(model)
  2. goodness of function:評估 model 的好壞
  3. pick the best function:挑一個最好的 model

神經網路找方法三步驟

身為機器學習方法之一的神經網路(Neural Network, 以下簡稱NN)也是一樣,如果以三步驟來找一個神經網路方法的話:

  1. 定義你的神經網路結構(NN structure)
  2. 找一個可以判斷你的 model 目前狀況與目標差距的指標,譬如說用 MSE(mean square error) 這種方法去算 loss function
  3. 接下來用倒傳遞學習法 (gradient descent, 梯度下降演算法) 去挑一個最好的model (optimization)

所以對手寫數字辨識系統來說,當確定實作環境與輸入圖後,我們會先建一個 NN model,訓練好並用相關指標或正確率,確定這個 NN model 就是我最想要的後,再拿這個 model 做數字預測。


三步驟展示

這邊以 Keras 套件拼出簡單的 NN 架構,展示找方法三步驟中每一個步驟的程式實例大概長怎樣,裡面的細節、意義及完整拼出一個以 NN 作為辨識方法的手寫辨識系統會再之後說。

1. 找一個 model / 建立一個 NN structure

定義我要疊一個沒有回饋輸入的線性執行模型(sequential model),確定input格式、我要疊幾層、每層要有幾個節點,並將每層輸出做一個非線性的轉換(activation function)

# 定義此 model 模式為 sequential model
model = Sequential()  
# 建 model,output 10 類
model.add(Dense(16, input_dim=784, kernel_initializer='normal', activation='relu'))  
model.add(Dense(10, activation='softmax'))

在這邊我們可以用 print(model.summary()) 去看此 model 架構(下圖左) ,或用圖示(下圖右)來幫助理解。

https://ithelp.ithome.com.tw/upload/images/20220921/20131719CI6rmy6RZV.png


2. 定義 model 的好壞並選最好的

主要會以 model.compile 定義 model 執行時的損失/誤差函數(loss)優化/最佳化方式(optimizer),及成效衡量/結果評估指標(mertrics)

  • loss:損失/誤差函數,跟我們期望的差多少,譬如說用 MSE(mean square error)回饋給函數修正
  • optimizer:優化/最佳化方式,你要用什麼方法修正,譬如說用倒傳遞演算法/梯度下降演算法+學習速率等
  • metrics成效衡量/結果評估指標,譬如說正確率,跟損失函數有點像,但此結果不會用在函數訓練過程中

然後用 model.fit模型訓練,並設定要訓練幾次、每幾筆資料為一組,計算它們的 loss 再去更新權重等,有些人會在這邊切一部份出來驗證。

# model.compile
model.compile(loss='categorical_crossentropy', 
optimizer='adam', 
metrics=['accuracy'])

# 訓練
train_history = model.fit(x_train, y_train, epochs=10, batch_size=50) #validation_split=0.2

執行過程會在 console 中顯示(譬如下圖左),在這邊也可以藉 python 的 matplotlib 套件去畫出各種圖(如下圖右)。

https://ithelp.ithome.com.tw/upload/images/20220921/20131719yQ4bfIph2L.png


ps. 測試/ 預測

然後我們可以拿這訓練好的模型去做測試model.evaluate預測model.predict,並根據結果推測原因、做進一步 training 部分或testing 部分的修改等。


今天我們講了找方法的三步驟,由於提到的專有名詞比較多,大家可以先對步驟有概念就好。接下來我們會一一聊裡面細節,了解前置作業如相關實作環境和資料集,再花一天時間把所有程式組起來,把這些細項拼起來後,再來看若效果達不到預期,可以怎麼調整。大家明天見XD

[註1]三步驟概念來自李宏毅老師機器學習課程


上一篇
[DAY6] 機器學習學什麼?釐清手寫數字辨識系統的學習情境與類型
下一篇
[DAY8] 讓 NN model 引入非線性-激勵函數(activation function)
系列文
機器學習的 hello world - 用手寫數字辨識系統學習 ML 的 30 天30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言