iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 6
0
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 6

【06】 tensorflow 細看存檔:checkpoint 篇

不管你用什麼訓練框架,終會遇到需要把過程或結果存起來的需求,一般來說,tensorflow 跟存檔有關的主要有兩種方式,一個是存成 checkpoint 檔,另一個是存成 pb 檔,這兩者在不同的場合會有不同用途,間單來說,如果你要保存目前session狀態,那我建議存成checkpoint,如果你今天模型架構已確定,或者是訓練結束,那請存成pb檔,pb 檔的詳細介紹我會放在更後面一點,今天我主要想介紹的是第一種 checkpoint 檔。

在開始介紹之前,你必須要先知道 checkpoint 檔的中心思想是保存 session 的狀態,所以你必須先宣告 tf.train.Saver()。

這邊示範撰寫一個簡單的小型網路。

input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32)
net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.conv2d(net, 32, (3, 3), strides=(1, 1), padding='same', name='conv_2')
net = tf.layers.conv2d(net, 64, (3, 3), strides=(2, 2), padding='same', name='conv_3')

產生的圖表如圖:
https://ithelp.ithome.com.tw/upload/images/20190914/20107299tM41MQSnSU.png

然後我再定義一個把第一層conv的第一個kernel weight印出來的方法。

def get_first_filter_value(sess):
    tensor = tf.get_default_graph().get_tensor_by_name("conv_1/kernel/read:0")
    return sess.run(tensor)[1, :, :, 1]

kernel weight 節點位置圖:
https://ithelp.ithome.com.tw/upload/images/20190914/201072998CPZ4928Bu.png

接著初始化參數後就可以把目前session存成ckpt了。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    result = get_first_filter_value(sess)

    saver.save(sess, "../ckpt/model.ckpt")

為了要驗證save之後在load回來,第一層conv的第一個kernel weight一樣,我們把剛剛的save包成方法,並回傳數值。

if __name__ == '__main__':
    save_value = save()
    load_value = load()

    print(save_value)
    print(load_value)

    assert np.alltrue(save_value == load_value)

再來撰寫 load() 這個方法,在load ckpt時,有個重點需要注意!ckpt 的恢復分為兩個步驟,第一個步驟tf.train.import_meta_graph 是讀取meta 檔,meta 檔包含 tensor 的架構宣告。

第二個步驟 saver.restore 才是把剛剛tensor的權重值補回來,我這邊寫了一個小實驗,我在 restore 前,先init 初始值,最後拿到的 conv kernel weight 其實一樣,因為初始化之後,權重值被 restore 的蓋掉了。

def load():
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('../ckpt/model.ckpt.meta')  # 步驟一

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        saver.restore(sess, '../ckpt/model.ckpt')  # 步驟二

        result = get_first_filter_value(sess)

    return result

但!如果你先 restore ckpt 再 init 初始值,你會發現權重全部又變成初始值,和之前 save 的 weight 不一樣了。

def load():
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('../ckpt/model.ckpt.meta')

        saver.restore(sess, '../ckpt/model.ckpt')  # 實驗先restore
        
        sess.run(tf.global_variables_initializer())  # 再init,會把restore的數值覆蓋掉!
        sess.run(tf.local_variables_initializer())

        result = get_first_filter_value(sess)

    return result

實驗結果一,正常的 restore:
https://ithelp.ithome.com.tw/upload/images/20190914/20107299MirxydjFB8.png

實驗結果二,先 restore 再init:
https://ithelp.ithome.com.tw/upload/images/20190914/20107299CghJjjmAuo.png

希望這個小實驗可以讓大家更理解每行程式碼的意義,而不是從 google 或 stackoverflow 上囫圇吞棗般亂貼上!

github原始碼


上一篇
【05】tensorflow 的 convolution 方式好多種,我該用哪個...
下一篇
【07】tensorflow 細看存檔:save pb 篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言