iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 23
0

昨天的文章把訓練的技巧做了一個總複習,從今天開始到結束前的文章,我會著重在優化模型上,這是我摸索了好大一段時間才有理解的方法,文章內容相對進階,請大家期待囉!

在真正開始優化之前,我們必須複習一下產生 pb 檔的方法,day7 我介紹了 tf.graph_util.convert_variables_to_constants 這個方法來將訓練過程中不必要的節點去除,像是 gradient 數值、data reader 和 global step 等,我們在推論時完全不需要他們,就可以透過這個方法去除,那這個方法實際又是怎麼知道哪些節點該砍哪些該保留呢?
https://ithelp.ithome.com.tw/upload/images/20191001/20107299JmAGUKJj8p.png

我們來粗略看一下 pbtxt 檔:
https://ithelp.ithome.com.tw/upload/images/20191001/201072999fZRPxOc6a.png

我們可以看到,運算最後的節點 final_dense/MatMul,其實有兩個 input,分別是 flatten 和 final_dense/kernel/read,而我們再來看看 flatten 節點:
https://ithelp.ithome.com.tw/upload/images/20191001/20107299KaZIeiExx5.png

可以看到 flatten 也有 input 分別是 backend/max_pool_3/MaxPool 和 flatten/shape,各位有感覺了嗎?沒錯!如果我知道最後輸出的節點名稱,我就可以利用這個節點的 input 一路往回推,一直推到 placeholder,藉此找完我需要保留的節點有哪些。

所以透過

saver = tf.train.import_meta_graph('../ckpt/model.ckpt-720.meta')
saver.restore(sess, '../ckpt/model.ckpt-720')

frozen_graph = tf.graph_util.convert_variables_to_constants(
   sess, tf.get_default_graph().as_graph_def(), ['final_dense/MatMul'])

我們拿到的 frozen_graph 就是個保留我們所需的 graph 圖啦!
https://ithelp.ithome.com.tw/upload/images/20191001/20107299UdycVStVYI.png

後面幾天的文章要談的優化,必須要建立兩個原則上:
1.優化完後,要確輸入與輸出之間的保節點仍然是相通的。
2.當丟一張圖片進去時,優化前和優化後的數值不能差太多。

為了要達成第一點,可以參考 tensorflow 官方寫了個方法 ensure_graph_is_valid,來檢查。

def ensure_graph_is_valid(graph_def):
   node_map = {}
   for node in graph_def.node:
       if node.name not in node_map:
           node_map[node.name] = node
       else:
           raise ValueError("Duplicate node names detected for ", node.name)
   for node in graph_def.node:
       for input_name in node.input:
           input_node_name = node_name_from_input(input_name)
           if input_node_name not in node_map:
               raise ValueError("Input for ", node.name, " not found: ",
                                input_name)

def node_name_from_input(node_name):
   if node_name.startswith("^"):
       node_name = node_name[1:]
   m = re.search(r"(.*):\d+$", node_name)
   if m:
       node_name = m.group(1)
   return node_name

很多 for 迴圈...XD,第一個 for 先檢查所有節點有沒有重覆的,並把節點存到 node_map,第二個 for 再跑一次所有節點,檢查每個節點的 input 有沒有在 node_map 裡,就醬!

而關於第二個檢查,我這邊自己定義了一個 test 的方法:

def test():
   graph_def = tf.get_default_graph().as_graph_def()
   with gfile.FastGFile('../pb/frozen_shape_23.pb', 'rb') as f:
       graph_def.ParseFromString(f.read())
   tf.import_graph_def(graph_def, name='')

   with tf.Session() as sess:
       input_node = tf.get_default_graph().get_tensor_by_name(
           "input_node:0")
       training_node = tf.get_default_graph().get_tensor_by_name(
           "training:0")
       output_node = tf.get_default_graph().get_tensor_by_name(
           "final_dense/MatMul:0")

       image = cv2.imread('../05/ithome.jpg')
       image = cv2.resize(image, (128, 128))
       image = image - 127.5
       image = image * 0.0078125

       output = sess.run(output_node, feed_dict={input_node: np.expand_dims(image, 0), training_node: False})
       print(output)

內容就是丟一張圖片進去,在最後 final_dense/MatMul 要拿到相同的數值,執行如下:
https://ithelp.ithome.com.tw/upload/images/20191001/20107299Chfl53s5Uy.png

後面幾天介紹的優化,我必須要確保能產生相同的數值結果,請大家期待囉!

github原始碼


上一篇
【22】tensorflow 訓練技巧觀念混合運用篇
下一篇
【24】tensorflow 模型優化手術:把一般輸入改成 pre processing 輸入篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言