每次都重新訓練模型十分浪費時間,所以今天就來介紹如何把訓練好的模型儲存起來。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
from time import time
# Read the MNIST data sets from the 'MNIST_data' directory with one-hot labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Input: a batch (any size) of flattened images, 784 values per image.
x = tf.placeholder("float", [None, 784])

# NOTE(review): none of the variables below is given an explicit name, so
# their graph names presumably come from creation order (W, b, W1, b1).
# Keep this order identical wherever the checkpoint is saved or restored.

# Hidden layer: 784 -> 1000 units followed by ReLU.
W = tf.Variable(tf.random_normal([784, 1000]))
b = tf.Variable(tf.random_normal([1, 1000]))
XWb = tf.nn.relu(tf.matmul(x, W) + b)

# Output layer: 1000 -> 10 raw logits (the softmax is applied inside the loss).
W1 = tf.Variable(tf.random_normal([1000, 10]))
b1 = tf.Variable(tf.random_normal([1, 10]))
XWb1 = tf.matmul(XWb, W1) + b1

# The two layers above are the inlined form of:
#   h1 = layer(out_dim=1000, in_dim=784, inputs=x, activation=tf.nn.relu)
#   y  = layer(out_dim=10, in_dim=1000, inputs=h1, activation=None)

# One-hot ground-truth labels, 10 classes.
y_label = tf.placeholder("float", [None, 10])

# Mean softmax cross-entropy over the batch, minimised with Adam.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=XWb1, labels=y_label)
)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

# Accuracy: fraction of samples whose arg-max logit matches the arg-max label.
correct_prediction = tf.equal(tf.argmax(y_label, 1), tf.argmax(XWb1, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# Training hyper-parameters.
train = 15       # number of epochs
batchSize = 110  # samples fed per optimisation step
total = 500      # optimisation steps per epoch

# Per-epoch history of the validation metrics (plotted after training).
loss_list = []
epoch_list = []
accuracy_list = []

starttime = time()

# The session is intentionally left open: the trained variable values live
# in it and are written to disk afterwards with tf.train.Saver.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(train):
    # One epoch = `total` mini-batch updates.
    for _ in range(total):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})

    # Evaluate once per epoch on the validation split and record the result.
    loss1, acc = sess.run(
        [loss, accuracy],
        feed_dict={x: mnist.validation.images,
                   y_label: mnist.validation.labels},
    )
    epoch_list.append(epoch)
    loss_list.append(loss1)
    accuracy_list.append(acc)
    print("Train Epoch:", '%02d' % (epoch+1), "loss=", "{:.9f}".format(loss1), " accuracy=", acc)

duration = time() - starttime
print("train finished takes:", duration)
# Draw the validation accuracy recorded after each epoch on a small figure.
fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.plot(epoch_list, accuracy_list, label="accuracy")
plt.xlabel('epoch')
plt.ylabel('accuracy')
# Fix the visible range of the y-axis to [0.8, 1].
plt.ylim(0.8, 1)
plt.legend()
plt.show()
def plot_images_labels(images, labels, prediction, idx, num=10):
    """Show consecutive samples in a 5x5 grid, titled with label and prediction.

    Args:
        images: indexable collection of images; each one is reshaped to
            (28, 28) before being drawn.
        labels: indexable collection of label vectors; the arg-max of each
            vector is shown as the label.
        prediction: indexable collection of predicted classes, aligned with
            ``images``. Pass an empty sequence to show labels only.
        idx: index of the first sample to display.
        num: number of samples to display. At most 25 are drawn (the grid is
            5x5), and never more than remain after ``idx``.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)

    # Cap at the grid capacity and at the end of the data, so a large `idx`
    # or `num` draws fewer cells instead of raising IndexError.
    num = min(num, 25, len(images) - idx)

    for cell in range(num):
        sample = idx + cell
        ax = plt.subplot(5, 5, cell + 1)
        ax.imshow(np.reshape(images[sample], (28, 28)), cmap='binary')

        title = "label=" + str(np.argmax(labels[sample]))
        # len() rather than truthiness: `prediction` may be a NumPy array,
        # whose truth value is ambiguous.
        if len(prediction) > 0:
            title += ",predict=" + str(prediction[sample])
        ax.set_title(title, fontsize=10)

        # Hide the axis ticks; they carry no information for an image.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
以上跟DAY19-20的程式碼一樣,只是把它攤開而已。因為筆者使用副程式的方式建立變量時,執行中遇到很多錯誤,這部份筆者還在研究,所以這裡先採用目前可行的寫法。
saver = tf.train.Saver()  # no var_list given; NOTE(review): presumably covers every saveable variable in the graph
saver.save(sess, save_path="./model.ckp")  # write the variable values held by `sess` to this checkpoint path
說明:
第1行:建立內建的tf.train.Saver物件
第2行:儲存目前會談(Session)中的變量,後面的參數是你想存放的路徑及檔名。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
from time import time
def plot_images_labels(images, labels, prediction, idx, num=10):
    """Show consecutive samples in a 5x5 grid, titled with label and prediction.

    Args:
        images: indexable collection of images; each one is reshaped to
            (28, 28) before being drawn.
        labels: indexable collection of label vectors; the arg-max of each
            vector is shown as the label.
        prediction: indexable collection of predicted classes, aligned with
            ``images``. Pass an empty sequence to show labels only.
        idx: index of the first sample to display.
        num: number of samples to display. At most 25 are drawn (the grid is
            5x5), and never more than remain after ``idx``.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)

    # Cap at the grid capacity and at the end of the data, so a large `idx`
    # or `num` draws fewer cells instead of raising IndexError.
    num = min(num, 25, len(images) - idx)

    for cell in range(num):
        sample = idx + cell
        ax = plt.subplot(5, 5, cell + 1)
        ax.imshow(np.reshape(images[sample], (28, 28)), cmap='binary')

        title = "label=" + str(np.argmax(labels[sample]))
        # len() rather than truthiness: `prediction` may be a NumPy array,
        # whose truth value is ambiguous.
        if len(prediction) > 0:
            title += ",predict=" + str(prediction[sample])
        ax.set_title(title, fontsize=10)

        # Hide the axis ticks; they carry no information for an image.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
# Read the MNIST data sets from the 'MNIST_data' directory with one-hot labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Input: a batch (any size) of flattened images, 784 values per image.
x = tf.placeholder("float", [None, 784])

# NOTE(review): none of the variables below is given an explicit name, so
# their graph names presumably come from creation order (W, b, W1, b1).
# Keep this order identical to the script that wrote the checkpoint.

# Hidden layer: 784 -> 1000 units followed by ReLU.
W = tf.Variable(tf.random_normal([784, 1000]))
b = tf.Variable(tf.random_normal([1, 1000]))
XWb = tf.nn.relu(tf.matmul(x, W) + b)

# Output layer: 1000 -> 10 raw logits (the softmax is applied inside the loss).
W1 = tf.Variable(tf.random_normal([1000, 10]))
b1 = tf.Variable(tf.random_normal([1, 10]))
XWb1 = tf.matmul(XWb, W1) + b1

# One-hot ground-truth labels, 10 classes.
y_label = tf.placeholder("float", [None, 10])

# Mean softmax cross-entropy over the batch, minimised with Adam.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=XWb1, labels=y_label)
)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

# Accuracy: fraction of samples whose arg-max logit matches the arg-max label.
correct_prediction = tf.equal(tf.argmax(y_label, 1), tf.argmax(XWb1, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
以上跟DAY19-20的程式碼一樣,只是把它攤開,並只留下定義模型的部份。載入之前必須先定義出與儲存時相同的變量,Saver才有對象可以把存檔中的數值放回去。
saver = tf.train.Saver(var_list=[W, W1, b, b1])  # handle only the four layer variables
sess = tf.Session()  # fresh session; no initializer is run, the values come from the checkpoint
saver.restore(sess, save_path="./model.ckp")  # load the values stored under this checkpoint path
prediction = sess.run(tf.argmax(XWb1, 1), feed_dict={x: mnist.test.images})  # arg-max class index per test image
plot_images_labels(mnist.test.images, mnist.test.labels, prediction, idx=10)  # default num=10 samples, starting at index 10
說明:
第1行:建立Saver,並指定要載入的變量(W、W1、b、b1)
第2行:宣告會談
第3行:載入儲存的模型
第4-5行:跟之前一樣,進行預測並畫出結果
到了這邊,就可以直接使用之前訓練好的模型了。對於很大的模型,或是電腦效能比較差的人來說,這十分方便。至於文章中遇到的問題,筆者會慢慢嘗試並解決,今天的文章就先到這邊。