本篇文章請搭配上篇使用,如果上篇 save pb 篇還沒看過的讀者建議往回跳。
好,今天要來介紹 load pb 檔的部分,相對於 checkpoint 的兩個步驟,load pb 就簡單許多,看個範例:
graph_def = tf.get_default_graph().as_graph_def()
with gfile.FastGFile(MODEL_PB, 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
我們先把 default graph 的 graph_def 取出來,搭配 with 語句讀取 pb 檔(這邊這個 with 跟一般 python open file 意思一樣,如果覺得陌生的同學可以先查查),然後我們把這個 file parse 到剛剛宣告 的graph_def,這樣我們的 graph 就吃到囉!
pbtxt 也是很類似,我們要靠 google.protobuf.text_format 來讀取。
graph_def = tf.get_default_graph().as_graph_def()
with gfile.FastGFile(MODEL_PBTXT, 'rb') as f:
text_format.Parse(f.read(), graph_def)
tf.import_graph_def(graph_def, name='')
但是執行後等等你就會發現 tensorflow 噴錯了XD
google.protobuf.text_format.ParseError: 1563:10 : Message type "tensorflow.GraphDef" should not have multiple "versions" fields.
這是什麼意思呢?如果你前面 graph 概念都有弄熟,那你大致應該猜到了,沒錯!撞名啦!
我們來看這個 pbtxt 裡是否有 versions 欄位。
versions { producer: 38 }
找到了,pbtxt 裡確實有這個屬性,因此它和你的 default graph 互相產生衝突。
我自己的解決辦法是,我不要使用 default graph 來用,我自己再產生新的 graph,就可以了。
graph_def = tf.GraphDef()
with gfile.FastGFile(MODEL_PBTXT, 'rb') as f:
text_format.Parse(f.read(), graph_def)
tf.import_graph_def(graph_def, name='')
Bonus: 動腦時間:
Q: 最後一個範例中,為什麼我import FROZEN_PBTXT 用default graph就沒有問題?
.
.
.
.
.
.
.
.
.
.
.
A: 因為之前 tf.graph_util.convert_variables_to_constants() 在產生時,指定的 output node 不需要 vesrion 這個欄位,產生出來的 frozen graph 就把這個欄位丟掉了,因此就不會撞名啦!
再來,我們比較原 pb 和 frozen_pb 兩者的差異,以下是原 pb 模型圖:
以下是 frozen_pb 模型圖
可以很清楚的觀察到,原 pb conv layer 的權重值還是一個初始化 function,所以當你 load 進來時,無法馬上使用,必須先初始權重,反觀 frozen_pb ,它已經存有權重資訊,一旦 load 進來後,就可以直接做推斷,但化句話說,你也很難再對模型做變動(重新初始化訓練),畢竟它已經 frozen 了。
恩....不得不說 tensorflow 的 API 有時候潛藏一些因果要靠自己摸索呢!