昨天的文章把訓練的技巧做了一個總複習,從今天開始到結束前的文章,我會著重在優化模型上,這是我摸索了好大一段時間才有理解的方法,文章內容相對進階,請大家期待囉!
在真正開始優化之前,我們必須複習一下產生 pb 檔的方法,day7 我介紹了 tf.graph_util.convert_variables_to_constants 這個方法來將訓練過程中不必要的節點去除,像是 gradient 數值、data reader 和 global step 等,我們在推論時完全不需要他們,就可以透過這個方法去除,那這個方法實際又是怎麼知道哪些節點該砍哪些該保留呢?
我們來粗略看一下 pbtxt 檔:
我們可以看到,運算最後的節點 final_dense/MatMul,其實有兩個 input,分別是 flatten 和 final_dense/kernel/read,而我們再來看看 flatten 節點:
可以看到 flatten 也有 input 分別是 backend/max_pool_3/MaxPool 和 flatten/shape,各位有感覺了嗎?沒錯!如果我知道最後輸出的節點名稱,我就可以利用這個節點的 input 一路往回推,一直推到 placeholder,藉此找完我需要保留的節點有哪些。
所以透過
saver = tf.train.import_meta_graph('../ckpt/model.ckpt-720.meta')
saver.restore(sess, '../ckpt/model.ckpt-720')
frozen_graph = tf.graph_util.convert_variables_to_constants(
sess, tf.get_default_graph().as_graph_def(), ['final_dense/MatMul'])
我們拿到的 frozen_graph 就是個保留我們所需的 graph 圖啦!
後面幾天的文章要談的優化,必須要建立兩個原則上:
1.優化完後,要確輸入與輸出之間的保節點仍然是相通的。
2.當丟一張圖片進去時,優化前和優化後的數值不能差太多。
為了要達成第一點,可以參考 tensorflow 官方寫了個方法 ensure_graph_is_valid,來檢查。
def ensure_graph_is_valid(graph_def):
node_map = {}
for node in graph_def.node:
if node.name not in node_map:
node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
for node in graph_def.node:
for input_name in node.input:
input_node_name = node_name_from_input(input_name)
if input_node_name not in node_map:
raise ValueError("Input for ", node.name, " not found: ",
input_name)
def node_name_from_input(node_name):
if node_name.startswith("^"):
node_name = node_name[1:]
m = re.search(r"(.*):\d+$", node_name)
if m:
node_name = m.group(1)
return node_name
很多 for 迴圈...XD,第一個 for 先檢查所有節點有沒有重覆的,並把節點存到 node_map,第二個 for 再跑一次所有節點,檢查每個節點的 input 有沒有在 node_map 裡,就醬!
而關於第二個檢查,我這邊自己定義了一個 test 的方法:
def test():
graph_def = tf.get_default_graph().as_graph_def()
with gfile.FastGFile('../pb/frozen_shape_23.pb', 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
input_node = tf.get_default_graph().get_tensor_by_name(
"input_node:0")
training_node = tf.get_default_graph().get_tensor_by_name(
"training:0")
output_node = tf.get_default_graph().get_tensor_by_name(
"final_dense/MatMul:0")
image = cv2.imread('../05/ithome.jpg')
image = cv2.resize(image, (128, 128))
image = image - 127.5
image = image * 0.0078125
output = sess.run(output_node, feed_dict={input_node: np.expand_dims(image, 0), training_node: False})
print(output)
內容就是丟一張圖片進去,在最後 final_dense/MatMul 要拿到相同的數值,執行如下:
後面幾天介紹的優化,我必須要確保能產生相同的數值結果,請大家期待囉!