iT邦幫忙

第 11 屆 iT 邦幫忙鐵人賽

DAY 7
0

昨天介紹了有關 checkpoint 檔的存取,今天來介紹 pb 檔,幫大家複習一下,pb檔和checkpoint的差別主要是 pb 檔使用時機是你模型已確定,準備匯出應用時,普遍會存的檔案格式。

不像 checkpoint 需要 session,pb 檔在你建完節點網路時,就可以保存。

input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32)
net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.conv2d(net, 32, (3, 3), strides=(1, 1), padding='same', name='conv_2')
net = tf.layers.conv2d(net, 64, (3, 3), strides=(2, 2), padding='same', name='conv_3')

tf.io.write_graph(tf.get_default_graph(), "../pb/", "model.pb", as_text=False)

除了 pb 檔,你也可以存成 pbtxt 檔,好處是你可以用文字編輯器看格式。

tf.io.write_graph(tf.get_default_graph(), "../pb/", "model.pbtxt", as_text=True)

pbtxt 某部分txt格式:

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 100
        }
        dim {
          size: 100
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "conv_1/kernel/Initializer/random_uniform/shape"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@conv_1/kernel"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 4
          }
        }
        tensor_content: "\003\000\000\000\003\000\000\000\003\000\000\000 \000\000\000"
      }
    }
  }
}

但是要注意的是,pbtxt比較佔空間,如果要應用還是建議存成binary的pb檔(即.pb)。
https://ithelp.ithome.com.tw/upload/images/20190915/20107299SejDwMZcKW.png

接下來有個問題,這樣的 pb 檔只能算是空殼,因為他只有網路架構,但是裡面沒有任何權重值啊!沒錯,所以這邊我來示範如何將權重一起保存進 pb 檔,一樣我們必須先 init 權重初始值。

再來我們需要 tf.graph_util.convert_variables_to_constants(),來封存權重值,要帶進去的參數很簡單,sessiongraph的定義你的 output 節點名稱,最基本的只需要這三樣,拿到封存的 graph 後 (frozen_graph),就可以保存下來,這邊一樣暫存成 pb 和 pbtxt 兩種格式。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['conv_3/BiasAdd'])

    tf.io.write_graph(frozen_graph, "../pb/", "frozen_model.pb", as_text=False)
    tf.io.write_graph(frozen_graph, "../pb/", "frozen_model.pbtxt", as_text=True)

從檔案大小來看,你的 pb 因為多了權重值變肥了,你也可以比對兩者 pbtxt,看看是多了哪些變數。
https://ithelp.ithome.com.tw/upload/images/20190915/20107299IsFhbwkyAn.png

可以觀察到 frozen_model.pbtxt 的 conv_1/kernel 多了 tensor_content 的權重值:
https://ithelp.ithome.com.tw/upload/images/20190915/20107299PTAnNc57fx.png

最後,有個很重要的觀念,上面 tf.graph_util.convert_variables_to_constants(),我們有指定 output 的節點,tensorflow 會根據這個節點往前面推測,總共要把哪些 node 保存下來,詳細的內容我會在之後 optimze 篇章再做更詳細的介紹。

有此可知,今天當你拿到一份 pb 檔時,你無法確定這個 pb 檔是已含權重或未含權重的模型檔,你必須讀取後才能得知,所以帶有權重的 pb 檔,我們習慣在名稱前多加 frozen 前綴來區別。

github原始碼


上一篇
【06】 tensorflow 細看存檔:checkpoint 篇
下一篇
【08】tensorflow 細看存檔:load pb篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31

尚未有邦友留言

立即登入留言