iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 9
0
AI & Machine Learning

探索 Microsoft CNTK 機器學習工具系列 第 9

MNIST 手寫數字資料集

  • 分享至 

  • xImage
  •  

Introduction

MNIST 是一個手寫數字的圖像資料集,我們將用於圖像辨識的範例之中。

MNIST 資料集,可至 THE MNIST DATABASE下載。

MNIST 資料集常用於機器學習訓練和測試的教學。
資料集包含 60000 個訓練圖片和 10000 個測試圖片,每個圖片大小是 28 * 28 像素。

Tasks

學習資源:cntk\Tutorials\CNTK_103A_MNIST_DataLoader.ipynb
下載網路資料,預先資料處哩,儲存於本地端資料夾。

引用相關組件

# 引用相關組件
# 相容性需求,若使用舊版pyton時,可使用新版python函式
from __future__ import print_function
import gzip
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import struct
import sys

try: 
    from urllib.request import urlretrieve 
except ImportError: 
    from urllib import urlretrieve

# 設定繪圖組件繪圖在當前界面 - Jupyter Notebook
%matplotlib inline

1.資料讀取(Data reading):

讀取 MNIST 資料集,將資料解壓縮,並分為訓練資料集及測試資料集

宣告函式:讀取 MNIST 並格式化為 28 * 28 陣列

def loadData(src, cimg):
    print ('Downloading ' + src)
    gzfname, h = urlretrieve(src, './delete.me')
    print ('Done.')
    
    try:
        with gzip.open(gzfname) as gz:
            n = struct.unpack('I', gz.read(4))
            
            # 判斷是否為 MNIST 資料集
            if n[0] != 0x3080000:
                raise Exception('Invalid file: unexpected magic number.')
                
            # 計算資料筆數
            n = struct.unpack('>I', gz.read(4))[0]
            
            if n != cimg:
                raise Exception('Invalid file: expected {0} entries.'.format(cimg))
                
            crow = struct.unpack('>I', gz.read(4))[0]
            ccol = struct.unpack('>I', gz.read(4))[0]
            
            if crow != 28 or ccol != 28:
                raise Exception('Invalid file: expected 28 rows/cols per image.')
                
            # 讀取資料
            res = np.fromstring(gz.read(cimg * crow * ccol), dtype = np.uint8)
    finally:
        os.remove(gzfname)
    return res.reshape((cimg, crow * ccol))

宣告函式:loadLabels 讀取每個圖像對應的標籤資料

def loadLabels(src, cimg):
    print ('Downloading ' + src)
    gzfname, h = urlretrieve(src, './delete.me')
    print ('Done.')
    try:
        with gzip.open(gzfname) as gz:        
            n = struct.unpack('I', gz.read(4))
            
            # 判斷是否為 MNIST 資料集
            if n[0] != 0x1080000:
                raise Exception('Invalid file: unexpected magic number.')
                
            # 計算資料筆數
            n = struct.unpack('>I', gz.read(4))
            
            if n[0] != cimg:
                raise Exception('Invalid file: expected {0} rows.'.format(cimg))
                
            # 讀取資料
            res = np.fromstring(gz.read(cimg), dtype = np.uint8)
    finally:
        os.remove(gzfname)
    return res.reshape((cimg, 1))

宣告函式:將圖像及標籤資料組合,水平(horizontally)展開為向量

def try_download(dataSrc, labelsSrc, cimg):

    data = loadData(dataSrc, cimg)
    labels = loadLabels(labelsSrc, cimg)
    
    return np.hstack((data, labels))

資料下載

# 訓練資料集:圖像及標籤資料
url_train_image = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
url_train_labels = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
num_train_samples = 60000

print("Downloading train data")
train = try_download(url_train_image, url_train_labels, num_train_samples)

# 測試資料集:圖像及標籤資料
url_test_image = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
url_test_labels = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
num_test_samples = 10000

print("Downloading test data")
test = try_download(url_test_image, url_test_labels, num_test_samples)

資料視覺化

# 從訓練資料集中,亂數挑選一個圖像資料,繪製圖像及印出對應的標籤資料
sample_number = 5001

plt.imshow(train[sample_number,:-1].reshape(28,28), cmap="gray_r")
plt.axis('off')

print("Image Label: ", train[sample_number,-1])

2.資料處理(Data preprocessing):

圖像資料:
將資料下載後儲存於在本地端資料夾,本地文件夾中保存圖片:圖像儲存為 28 * 28 = 784 長度的向量。
MNIST

標籤資料:
使用有效編碼(One-Hot Encoding)編碼格式儲存,用於表示名目尺度的一種編碼格式,使用時較有效率。
One-Hot Encoding

宣告函式:savetxt 將資料儲存為 CNTK 相容的資料格式

def savetxt(filename, ndarray):
    dir = os.path.dirname(filename)

    if not os.path.exists(dir):
        os.makedirs(dir)

    if not os.path.isfile(filename):
        print("Saving", filename )
        
        with open(filename, 'w') as f:
        
            labels = list(map(' '.join, np.eye(10, dtype=np.uint).astype(str)))
            
            for row in ndarray:
                row_str = row.astype(str)
                label_str = labels[row[-1]]
                feature_str = ' '.join(row_str[:-1])
                f.write('|labels {} |features {}\n'.format(label_str, feature_str))
    else:
        print("File already exists", filename)

呼叫 savetxt 函式將資料儲存於本地端。
cntk\Examples\Image\DataSets\MNIST

data_dir = os.path.join("..", "Examples", "Image", "DataSets", "MNIST")
if not os.path.exists(data_dir):
    data_dir = os.path.join("data", "MNIST")

print ('Writing train text file...')
savetxt(os.path.join(data_dir, "Train-28x28_cntk_text.txt"), train)

print ('Writing test text file...')
savetxt(os.path.join(data_dir, "Test-28x28_cntk_text.txt"), test)

print('Done')

上一篇
前饋神經網路
下一篇
MNIST 手寫數字資料集:邏輯迴歸模型
系列文
探索 Microsoft CNTK 機器學習工具30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言