我們現在已經學會了要怎麼印出照片,那我們要如何利用機器學習得到預測的數值呢?接下來的程式碼就是要來解釋這個部分。
def show_images_labels_predictions(images,labels,
predictions,start_id,num):
plt.gcf().set_size_inches(12, 14)
if num>25: num=25
for i in range(num):
ax=plt.subplot(5,5, i+1)
ax.imshow(images[start_id], cmap='binary')
title = 'label = ' + str(labels[start_id])
ax.set_title(title,fontsize=12)
ax.set_xticks([]);ax.set_yticks([])
start_id+=1
plt.show()
程式說明:
1~2:定義一個show_images_labels_predictioins函數帶有5個參數分別為images(數字圖片)、labels(自帶的真實值)、predictions(之後會解釋,但現在不會用到)、start_id(照片要從哪個開始顯示)、num(顯示照片數量)
3:設定數字圖片的長寬為12、14英吋
4:檢查num(顯示照片數量)是否超出最大值25,若超出則將num值設為25,以避免超出5*5子圖網格的範圍
5:建立了一個迴圈,從start_id開始,遞增到 start_id + num - 1,共執行num次。這決定了要在子圖網格中顯示的圖像範圍。
6:建立一個在5*5子圖網格且編號為i+1的對象ax
7:使用ax.imshow()函數來顯示數字圖片,第1個參數為images[start_id]為照片的索引值,第2個參數cmap=’binary’為設定圖片配色
9:設一個變數title,文字為label+數字圖片自帶的真實值,如label=5
10:設定title文字的字體大小為12
11:將x和y軸的刻度設置為空,以去除子圖的刻度標籤。
12:迴圈中的start_id每跑一次+1,用於不重複顯示下一張照片
13:顯示所有子圖
編輯完這個函數後我們就可以利用show()方法來呼叫剛剛寫好的程式碼了
show_images_labels_predictions(train_feature,train_label,[],0,10)