iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 7
1
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 7

【07】tensorflow 細看存檔:save pb 篇

昨天介紹了有關 checkpoint 檔的存取,今天來介紹 pb 檔,幫大家複習一下,pb檔和checkpoint的差別主要是 pb 檔使用時機是你模型已確定,準備匯出應用時,普遍會存的檔案格式。

不像 checkpoint 需要 session,pb 檔在你建完節點網路時,就可以保存。

input_node = tf.placeholder(shape=[None, 100, 100, 3], dtype=tf.float32)
net = tf.layers.conv2d(input_node, 32, (3, 3), strides=(2, 2), padding='same', name='conv_1')
net = tf.layers.conv2d(net, 32, (3, 3), strides=(1, 1), padding='same', name='conv_2')
net = tf.layers.conv2d(net, 64, (3, 3), strides=(2, 2), padding='same', name='conv_3')

tf.io.write_graph(tf.get_default_graph(), "../pb/", "model.pb", as_text=False)

除了 pb 檔,你也可以存成 pbtxt 檔,好處是你可以用文字編輯器看格式。

tf.io.write_graph(tf.get_default_graph(), "../pb/", "model.pbtxt", as_text=True)

pbtxt 某部分txt格式:

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 100
        }
        dim {
          size: 100
        }
        dim {
          size: 3
        }
      }
    }
  }
}
node {
  name: "conv_1/kernel/Initializer/random_uniform/shape"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@conv_1/kernel"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 4
          }
        }
        tensor_content: "\003\000\000\000\003\000\000\000\003\000\000\000 \000\000\000"
      }
    }
  }
}

但是要注意的是,pbtxt比較佔空間,如果要應用還是建議存成binary的pb檔(即.pb)。
https://ithelp.ithome.com.tw/upload/images/20190915/20107299SejDwMZcKW.png

接下來有個問題,這樣的 pb 檔只能算是空殼,因為他只有網路架構,但是裡面沒有任何權重值啊!沒錯,所以這邊我來示範如何將權重一起保存進 pb 檔,一樣我們必須先 init 權重初始值。

再來我們需要 tf.graph_util.convert_variables_to_constants(),來封存權重值,要帶進去的參數很簡單,sessiongraph的定義你的 output 節點名稱,最基本的只需要這三樣,拿到封存的 graph 後 (frozen_graph),就可以保存下來,這邊一樣暫存成 pb 和 pbtxt 兩種格式。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['conv_3/BiasAdd'])

    tf.io.write_graph(frozen_graph, "../pb/", "frozen_model.pb", as_text=False)
    tf.io.write_graph(frozen_graph, "../pb/", "frozen_model.pbtxt", as_text=True)

從檔案大小來看,你的 pb 因為多了權重值變肥了,你也可以比對兩者 pbtxt,看看是多了哪些變數。
https://ithelp.ithome.com.tw/upload/images/20190915/20107299IsFhbwkyAn.png

可以觀察到 frozen_model.pbtxt 的 conv_1/kernel 多了 tensor_content 的權重值:
https://ithelp.ithome.com.tw/upload/images/20190915/20107299PTAnNc57fx.png

最後,有個很重要的觀念,上面 tf.graph_util.convert_variables_to_constants(),我們有指定 output 的節點,tensorflow 會根據這個節點往前面推測,總共要把哪些 node 保存下來,詳細的內容我會在之後 optimze 篇章再做更詳細的介紹。

有此可知,今天當你拿到一份 pb 檔時,你無法確定這個 pb 檔是已含權重或未含權重的模型檔,你必須讀取後才能得知,所以帶有權重的 pb 檔,我們習慣在名稱前多加 frozen 前綴來區別。

github原始碼


上一篇
【06】 tensorflow 細看存檔:checkpoint 篇
下一篇
【08】tensorflow 細看存檔:load pb篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

2 則留言

0
richard_wang
iT邦新手 5 級 ‧ 2021-02-22 13:02:46

大家好, 我想把 .pb 轉成 .pbtxt, 程式碼如下:

import tensorflow as tf
from tensorflow.python.platform import gfile

def convert_pb_to_pbtxt(filename):
    with tf.io.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
    return
    

filename = 'efficientnet-lite4-11.pb'
convert_pb_to_pbtxt(filename);

運行時遇到如下的問題:

D:\download\onnx\onnx_2_pb\efficientNet>python convert_to_bptxt.py
2021-02-22 11:44:21.746474: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'cudart64_100.dll'; dlerror: cudart64_100.dll not found
2021-02-22 11:44:21.748660: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
Traceback (most recent call last):
File "C:\Users\User\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\framework\importer.py", line 501, in _import_graph_def_internal
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'Softmax:0': Node name contains invalid characters

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "convert_to_bptxt.py", line 15, in
convert_pb_to_pbtxt(filepath);
File "convert_to_bptxt.py", line 9, in convert_pb_to_pbtxt
tf.import_graph_def(graph_def, name='')
File "C:\Users\User\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "C:\Users\User\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\framework\importer.py", line 405, in import_graph_def
producer_op_list=producer_op_list)
File "C:\Users\User\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\framework\importer.py", line 505, in _import_graph_def_internal
raise ValueError(str(e))
ValueError: Node 'Softmax:0': Node name contains invalid characters

D:\download\onnx\onnx_2_pb\efficientNet>
主要是:ValueError: Node 'Softmax:0': Node name contains invalid characters
我的python 版本是:
D:\download\onnx\onnx_2_pb\efficientNet>python --version
Python 3.7.8

tensorflow 版本是:
D:\download\onnx\onnx_2_pb\efficientNet>pip show tensorflow
Name: tensorflow
Version: 1.15.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: c:\users\user\appdata\local\programs\python\python37\lib\site-packages
Requires: grpcio, google-pasta, opt-einsum, astor, termcolor, wrapt, gast, protobuf, numpy, absl-py, six, keras-applications, tensorflow-estimator, tensorboard, wheel, keras-preprocessing
Required-by:

想請教 這是什麼問題呢? 有什麼解決的方法呢?

Regards,

錯誤是 ValueError: Node 'Softmax:0': Node name contains invalid characters

似乎是設計網路最後的softmax取名有問題
你有用中文或奇怪的標點符號嗎?

你好
沒有, 這份.pb 是用 onnx-tf 從 onnx 轉出的, 如果onnx 原始就用這種命名 : xxxx:0, 或是 xxxx:1 , 就會在
tf.import_graph_def(graph_def, name='')
這一行 報錯.

看起來像是 ':' 造成的.

我要留言

立即登入留言