不管你用什麼訓練框架,終會遇到需要把過程或結果存起來的需求,一般來說,tensorflow 跟存檔有關的主要有兩種方式,一個是存成 checkpoint 檔,另一個是存成 pb 檔,這兩者在不同的場合會有不同用途,間單來說,如果你要保存目前session狀態,那我建議存成checkpoint,如果你今天模型架構已確定,或者是訓練結束,那請存成pb檔,pb 檔的詳細介紹我會放在更後面一點,今天我主要想介紹的是第一種 checkpoint 檔。
在開始介紹之前,你必須要先知道 checkpoint 檔的中心思想是保存 session 的狀態,所以你必須先宣告 tf.train.Saver()。
這邊示範撰寫一個簡單的小型網路。
input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32)
net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.conv2d(net, 32, (3, 3), strides=(1, 1), padding='same', name='conv_2')
net = tf.layers.conv2d(net, 64, (3, 3), strides=(2, 2), padding='same', name='conv_3')
產生的圖表如圖:
然後我再定義一個把第一層conv的第一個kernel weight印出來的方法。
def get_first_filter_value(sess):
tensor = tf.get_default_graph().get_tensor_by_name("conv_1/kernel/read:0")
return sess.run(tensor)[1, :, :, 1]
kernel weight 節點位置圖:
接著初始化參數後就可以把目前session存成ckpt了。
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
result = get_first_filter_value(sess)
saver.save(sess, "../ckpt/model.ckpt")
為了要驗證save之後在load回來,第一層conv的第一個kernel weight一樣,我們把剛剛的save包成方法,並回傳數值。
if __name__ == '__main__':
save_value = save()
load_value = load()
print(save_value)
print(load_value)
assert np.alltrue(save_value == load_value)
再來撰寫 load() 這個方法,在load ckpt時,有個重點需要注意!ckpt 的恢復分為兩個步驟,第一個步驟tf.train.import_meta_graph 是讀取meta 檔,meta 檔包含 tensor 的架構宣告。
第二個步驟 saver.restore 才是把剛剛tensor的權重值補回來,我這邊寫了一個小實驗,我在 restore 前,先init 初始值,最後拿到的 conv kernel weight 其實一樣,因為初始化之後,權重值被 restore 的蓋掉了。
def load():
tf.reset_default_graph()
with tf.Session() as sess:
saver = tf.train.import_meta_graph('../ckpt/model.ckpt.meta') # 步驟一
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
saver.restore(sess, '../ckpt/model.ckpt') # 步驟二
result = get_first_filter_value(sess)
return result
但!如果你先 restore ckpt 再 init 初始值,你會發現權重全部又變成初始值,和之前 save 的 weight 不一樣了。
def load():
tf.reset_default_graph()
with tf.Session() as sess:
saver = tf.train.import_meta_graph('../ckpt/model.ckpt.meta')
saver.restore(sess, '../ckpt/model.ckpt') # 實驗先restore
sess.run(tf.global_variables_initializer()) # 再init,會把restore的數值覆蓋掉!
sess.run(tf.local_variables_initializer())
result = get_first_filter_value(sess)
return result
實驗結果一,正常的 restore:
實驗結果二,先 restore 再init:
希望這個小實驗可以讓大家更理解每行程式碼的意義,而不是從 google 或 stackoverflow 上囫圇吞棗般亂貼上!