iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 10
0
AI & Machine Learning

探索 Microsoft CNTK 機器學習工具系列 第 10

MNIST 手寫數字資料集:邏輯迴歸模型

Introduction

我們使用 MNIST 資料集訓練一個邏輯迴歸模型。

MNIST 資料集是由手書圖像和背景噪音組成,常用於機器學習訓練和測試的教學。

光學字元辨識(Optical Character Recognition, OCR)應用廣泛,目前技術相當成熟,數位出版有時甚至倚靠 OCR 掃描實體出版品來產生電子書。

Logistic regression

邏輯迴歸(Logistic regression)是機器學習的常用方法,使用特徵參數權重的線性組合,來辨識樣本所屬類別的可能性。

邏輯迴歸(Logistic regression)有兩種模型:

  • 二元邏輯模型(Binary Logit Model),一個輸出神經元,以預測兩個類別。
    每個輸入的特徵(feature)參數都乘以一個對應的權重(weight),將所有結果相加後的總合,經過激活函數(activate function) Sigmoid 函數,生成 0 到 1 之間的機率(probability)值,與設定的門檻值(threshold)比較後,判定其標籤為 0 或 1 分成 2 類。
  • 多元邏輯斯迴歸(Multiple logistic regressio),多個輸出神經元,每個神經元預測一個類別。
    每個輸入的特徵(feature)參數都乘以一個對應的權重(weight),將所有結果相加後的總合,經過 Softmax 函式歸一化。

Logistic regression

Tasks

學習資源:cntk\Tutorials\CNTK_103B_MNIST_LogisticRegression.ipynb
使用 MNIST 資料集訓練一個邏輯迴歸模型,將圖像分成 10 個分類,即 0 ~ 9。

引用相關組件

# 引用相關組件
# 相容性需求,若使用舊版pyton時,可使用新版python函式
from __future__ import print_function 
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

import cntk as C
import cntk.tests.test_utils

# 測試並設定使用 CPU 或 GPU 作為目前測試環境
cntk.tests.test_utils.set_device_from_pytest_env() 
# 重新設定 CNTK 的亂數種子
C.cntk_py.set_fixed_random_seed(1)

# 設定繪圖組件繪圖在當前界面 - Jupyter Notebook
%matplotlib inline

1.資料讀取(Data reading):

MNIST 資料集常用於機器學習訓練和測試的教學。
資料集包含 60000 個訓練圖片和 10000 個測試圖片,每個圖片大小是 28 * 28 像素。
圖像儲存為 28 * 28 = 784 長度的向量。
分類為 10 類,數字 0 ~ 9。

定義資料維度

# 資料特徵:設定為 784 個輸入變數
input_dim = 784
# 資料標籤:設定為 10 個輸出變數
num_output_classes = 10

宣告函式:create_reader 讀取訓練資料集和測試資料集,資料檔案是 CNTK 的資料格式(CNTK text-format, CTF),將資料檔案反序列化(deserializer)

def create_reader(path, is_training, input_dim, num_label_classes):
    
    labelStream = C.io.StreamDef(field='labels', shape=num_label_classes, is_sparse=False)
    featureStream = C.io.StreamDef(field='features', shape=input_dim, is_sparse=False)
    
    deserailizer = C.io.CTFDeserializer(path, C.io.StreamDefs(labels = labelStream, features = featureStream))
            
    return C.io.MinibatchSource(deserailizer,
       randomize = is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)

讀取本地端資料檔案。
cntk\Examples\Image\DataSets\MNIST

data_found = False

for data_dir in [os.path.join("..", "Examples", "Image", "DataSets", "MNIST"),
                 os.path.join("data", "MNIST")]:
    train_file = os.path.join(data_dir, "Train-28x28_cntk_text.txt")
    test_file = os.path.join(data_dir, "Test-28x28_cntk_text.txt")
    if os.path.isfile(train_file) and os.path.isfile(test_file):
        data_found = True
        break
        
if not data_found:
    raise ValueError("Please generate the data by completing CNTK 103 Part A")
    
print("Data directory is {0}".format(data_dir))

2.資料處理(Data preprocessing):

3.建立模型(Model creation):

邏輯迴歸(Logistic regression)是機器學習的常用方法,使用特徵參數權重的線性組合,來辨識樣本所屬類別的可能性。

邏輯回歸:
計算一個樣本的評估值
z:評估值,輸出值,代表每個類別的機率,10 個分類值的機率加總為 1。

輸入圖像的每個像素特徵(feature),各有一個對應的權重(weight),經過 softmax 函數進行歸一化。
x:表示輸入特徵向量,MNIST 圖像的像素值,用來描述我們需要分類的樣本
W:權重矩陣,大小是 10 * 784
b:是長度為 10 的偏移量向量,每個對應一個輸出數字

Logistic regression

設定輸入

# 資料特徵:設定為 2 個輸入變數
# 資料標籤:設定為 2 個輸出變數
input = C.input_variable(input_dim)
label = C.input_variable(num_output_classes)

宣告函式:create_model,CNTK layers 模組提供 Dense 函數,以建立全連接層(Fully connected layer)

def create_model(features):
    with C.layers.default_options(init = C.glorot_uniform()):
        r = C.layers.Dense(num_output_classes, activation = None)(features)
        return r

z:評估值

# 將每個像素除以 255 ,將輸入轉換成 0 ~ 1 之間的機率值。
z = create_model(input/255.0)

4.訓練模型(Learning the model):

使用 softmax 函式歸一化。

loss = C.cross_entropy_with_softmax(z, label)

5.評估模型(Evaluation):

評估模型,比較輸出值和標籤值。

label_error = C.classification_error(z, label)

Hyperparameters

超參數(hyperparameters)是在訓練模型前設置的參數,需要人工參與在每次的訓練開始前加以調校,以比較進而提高模型效能。

隨機梯度下降算法(Gradient descent)是常見的最佳化演算法。計算成本函數(cost function),並產生一個新的參數,以進入下一個遞迴。


上一篇
MNIST 手寫數字資料集
下一篇
MNIST 手寫數字資料集:多層感知器
系列文
探索 Microsoft CNTK 機器學習工具30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言