iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 20
0
AI & Machine Learning

tensorflow python系列 第 20

DAY20 手寫數字辨識(2)

  • 分享至 

  • xImage
  •  

前言:

今天我們將使用上篇文章中所建立的模型,來進行訓練,並且查看結果。

程式開始:

(1)定義訓練方式

y_label=tf.placeholder("float",[None,10])

說明:

因為輸入的圖片數量不確定,所以第1維度設定為None,而因為第2個維度對應到0~9共10個數字所以設定為10。

(2)定義loss function

loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y,labels=y_label))

說明:

使用(https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits
這個方法來當作損失函數。

(3)定義optimizer

optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

說明:

使用
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
來當作優化方法。

(4)定義評估模型

correct_prediction = tf.equal(tf.argmax(y_label , 1), tf.argmax(y,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction , "float"))

說明:

第1行:比對兩個矩陣內資料是否相同。

第2行:將上列預測正確結果平均。

(5)進行訓練

train = 15
batchSize= 110
total= 500
loss_list=[];epoch_list=[];accuracy_list=[]
starttime=time()
sess=tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(train):
    for i in range(total):
        batch_x,batch_y=mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={x:batch_x , y_label:batch_y})
    loss1,acc=sess.run([loss,accuracy], feed_dict={x:mnist.validation.images, y_label:mnist.validation.labels})
    epoch_list.append(epoch)
    loss_list.append(loss1)
    accuracy_list.append(acc)
    print("Train Epoch:",'%02d' % (epoch+1),"loss=","{:.9f}".format(loss1)," accuracy=",acc)
duration = time()-starttime
print("train finished takes:",duration)

plt.plot(epoch_list,accuracy_list,label="accuracy")
fig =plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

說明:

第1行:train為定義訓練周期數。

第2行:batchSize為定義每次筆數。

第3行:訓練總筆數為55000,所以當每次為110筆的情況下,要進行500次。

第4行:初始化訓練周期,誤差,準確機率等等。

第5行:計算執行時間。

第6~7行:定義會談以及初始化。

第9~12行:開始進行訓練,訓練周期為15次每一周期為500批次。

第13行:使用驗證資料計算準確機率以及誤差。

第14~17行:顯示訓練結果,並且存入list。

第18~19行:計算並顯示訓練時間。

第21~28行:畫出準確機率圖表

(6)畫出圖形

def plot_images_labels(images,labels,prediction,idx,num=10):
    fig=plt.gcf()
    fig.set_size_inches(12,14)
    if num>25:
        num=25
    for i in range(0,num):
        ax=plt.subplot(5,5,1+i)
        ax.imshow(np.reshape(images[idx],(28,28)), cmap='binary')
        title="label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx=idx+1
    plt.show

說明:

第1行:傳入數字影像,真實值,預測結果,開始顯示資料的開頭,要顯示的筆數等資料。

第2~3行:設定顯示圖形的大小。

第4~5行:當顯示筆數參數大於25設定為25。

第6~15行:執行for迴圈,畫出num個數字。

第16行:開始畫圖。

(7)進行預測

prediction=sess.run(tf.argmax(y,1),feed_dict={x:mnist.test.images})

plot_images_labels(mnist.test.images,mnist.test.labels,prediction,0)

說明:

第1行:執行預測。
第3行:顯示前10筆預測結果

https://ithelp.ithome.com.tw/upload/images/20180115/20107535QsvUzlVUP2.png

結語:

我們使用多層感知器來進行手寫數字的辨識,再進行15次周期隱藏層神經元1000的情況下,準確率大概是0.95左右,再下篇開始會加入CNN來增加預測的準確機率。


上一篇
DAY19 手寫數字辨識(1)
下一篇
DAY21 CNN概念
系列文
tensorflow python30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言