iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 23
0
AI & Machine Learning

tensorflow python系列 第 23

DAY23 使用CNN增加手寫數字辨識準確度(2)

前言:

今天我們將使用上篇文章中所建立的模型,來進行訓練,並且查看結果。

程式開始:

(1)定義訓練方式

with tf.name_scope('optimizer'):
    y_label=tf.placeholder("float",[None,10], name='y_label')

    loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y,labels=y_label))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

說明:

跟DAY20的程式一模一樣。

(2)定義評估模型

with tf.name_scope('eva'):
    correct_prediction = tf.equal(tf.argmax(y_label , 1), tf.argmax(y,1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction , "float"))

說明:

跟DAY20的程式一模一樣。

(3)進行訓練

train = 3
batchSize= 100
total= 550
loss_list=[];epoch_list=[];accuracy_list=[]
starttime=time()

sess=tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(train):
    for i in range(total):
        batch_x,batch_y=mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={x:batch_x , y_label:batch_y})
    loss1,acc=sess.run([loss,accuracy], feed_dict={x:mnist.validation.images, y_label:mnist.validation.labels})
    epoch_list.append(epoch)
    loss_list.append(loss1)
    accuracy_list.append(acc)
    print("Train Epoch:",'%02d' % (epoch+1),"loss=","{:.9f}".format(loss1)," accuracy=",acc)
duration = time()-starttime
print("train finished takes:",duration)

說明:

跟DAY20的程式一模一樣但是周期下降,因為電腦有點不夠力跑的有點慢。

(4)畫出圖形

def plot_images_labels(images,labels,prediction,idx,num=10):
    fig=plt.gcf()
    fig.set_size_inches(12,14)
    if num>25:
        num=15
    for i in range(0,num):
        ax=plt.subplot(5,5,1+i)
        ax.imshow(np.reshape(images[idx],(28,28)), cmap='binary')
        title="label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx=idx+1
    plt.show()

說明:

跟DAY20的程式一模一樣。

(5)進行預測

prediction=sess.run(tf.argmax(y,1),feed_dict={x:mnist.test.images})

#prediction[:10]

plot_images_labels(mnist.test.images,mnist.test.labels,prediction,100)

說明:

跟DAY20的程式一模一樣。

(6)TensorBoard

merged=tf.summary.merge_all()
train_writer = tf.summary.FileWriter('log/CNN',sess.graph)

說明:

將log檔寫入jupyter notebook當下執行資料夾底下的log/CNN。

TensorBoard的用法於DAY12有教學過了,這邊就不多加贅述。

成果:

https://ithelp.ithome.com.tw/upload/images/20180116/201075358AsLaf0MwK.png

https://ithelp.ithome.com.tw/upload/images/20180116/20107535rNxscpoz3q.png

結語:

這樣就完成了CNN的部份了,Tensorflow也大概到了一個段落了,從下一篇文章開始,會開始介紹Keras這個高階的深度學習程式庫,後端仍然使用Tensorflow,程式語言也依然使用python,所以讀者不需要太過擔心。


上一篇
DAY22 使用CNN增加手寫數字辨識準確度(1)
下一篇
DAY24 Keras安裝
系列文
tensorflow python30

尚未有邦友留言

立即登入留言