第 16 屆 iThome 鐵人賽 (2023)
{%hackmd BJrTq20hE %}
剛踏入這個領域,很多人都以MNIST資料庫當作小試身手,可以說是machine learning 或 deep learning 的 hello world!那就來玩吧~
資料庫介紹
這是一個大型手寫數字資料庫,經常被使用在機器學習便是的領域,每一張圖片為 28*28大小,這個數據庫當中包含60000筆訓練影像和10000筆測試影像。
從keras 也有接口可以下載MNIST的資料,現在就來實際載資料,以及了解其資料型態,呈現方式。
# 1.匯入 Keras 及相關模組
import numpy as np
import pandas as pd
from keras.utils import np_utils
# 用來後續將 label 標籤轉為 one-hot-encoding
np.random.seed(10)
# 2.下載 mnist data
from keras.datasets import mnist
# 3.讀取與查看 mnist data
(X_train_image, y_train_label), (X_test_image, y_test_label) = mnist.load_data()
print("\t[Info] train data={:7,}".format(len(X_train_image)))
print("\t[Info] test data={:7,}".format(len(X_test_image)))
輸出:
>>>[Info] train data= 60,000
>>>[Info] test data= 10,000
# 1.訓練資料是由 images 與 labels 所組成
print("\t[Info] Shape of train data=%s" % (str(X_train_image.shape)))
print("\t[Info] Shape of train label=%s" % (str(y_train_label.shape)))
# 得:訓練資料是由 images 與 labels 所組成共有 60,000 筆, 每一筆代表某個數字的影像為 28x28 pixels.
# 2.建立 plot_image 函數顯示數字影像
import matplotlib.pyplot as plt
def plot_image(image):
fig = plt.gcf()
fig.set_size_inches(2,2)
plt.imshow(image, cmap='binary') # cmap='binary' 參數設定以黑白灰階顯示.
plt.show()
# 3.執行 plot_image 函數查看第 0 筆數字影像與 label 資料
plot_image(X_train_image[0])
print(y_train_label[0])
# 得:呼叫 plot_image 函數, 傳入 X_train_image[0], 也就是順練資料集的第 0 筆資料, 顯示結果可以看到這是一個數字 5 的圖形
輸出:
>>>[Info] Shape of train data=(60000, 28, 28)
>>>[Info] Shape of train label=(60000,)
# 1.建立 plot_images_labels_predict() 函數
# 為了後續能很方便查看數字圖形, 真實的數字與預測結果
def plot_images_labels_predict(images, labels, prediction, idx, num=10):
fig = plt.gcf()
fig.set_size_inches(12, 14)
if num > 25: num = 25
for i in range(0, num):
ax=plt.subplot(5,5, 1+i)
ax.imshow(images[idx], cmap='binary')
title = "l=" + str(labels[idx])
if len(prediction) > 0:
title = "l={},p={}".format(str(labels[idx]), str(prediction[idx]))
else:
title = "l={}".format(str(labels[idx]))
ax.set_title(title, fontsize=10)
ax.set_xticks([]); ax.set_yticks([])
idx+=1
plt.show()
plot_images_labels_predict(X_train_image, y_train_label, [], 0, 10)
輸出:
>>>
https://github.com/bubbliiiing/yolov4-pytorch