iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 25
1
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 25

【25】tensorflow 模型優化手術:每次推論都要對 training node 設 false 很麻煩,如何鎖住 training node 篇

  • 分享至 

  • xImage
  •  

昨天介紹了對模型加入預處理的修改,今天要介紹的就是 training 節點的修改啦,原本的模型中,我們已用了 placeholder_with_default 來只指定預設值,但是如果你先前的模型是直接使用 placeholder 的話,會變成每次推論都要指定布林值,這時就可以用上今天的方法將 training 節點封住。

和昨天預處理不一樣的地方在於,training 節點指向多個 layer,包括 dropout 和 batch normalization 都有,因此,我們必須每個節點跑一輪後,檢查開節點的 input 是否有 training,如果有就改用 false constant 堵住。

跟上次的程式碼差不多,只是這次優化中,我們在 add_preprocessing() 後面多了 remove_training_node() 方法。

# opt start #

preprocess_gd = add_preprocessing('backend/conv_1/Conv2D', 'input_node')
update_graph(preprocess_gd)

no_training_gd = remove_training_node('training')
update_graph(no_training_gd)

# opt end #

而 remove_training_node() 前半段,我們先宣告所需要的新節點跟創建等等用來產生新節點的 list:

def remove_training_node(is_training_node):
   false_node = tf.constant(False, dtype=tf.bool, shape=(), name='false_node')

   old_gd = tf.get_default_graph().as_graph_def()
   old_nodes = old_gd.node  # old nodes from graph

   nodes_after_modify = []

接著,就是中間處理的部份,我們一樣透過 for 迴圈繞一次所有節點並拷貝資料,然後把 input 砍掉,替換成新的 input,如果 input 中是 training 節點就換成 false_node,反之則保留原本 input,詳細實作如下:

for node in old_nodes:
   new_node = node_def_pb2.NodeDef()  # 產生新節點
   new_node.CopyFrom(node)  # 拷貝舊節點資訊到新節點
   input_before_removal = node.input  # 把舊節點的inputs暫存起來
   del new_node.input[:]  # 就把該inputs全部去除
   for full_input_name in input_before_removal:  # 然後再for跑一次剛剛刪除的inputs
       if full_input_name == is_training_node:  # inputs中若有training_node
           new_node.input.append(false_node.op.name)  # 改塞false_node給它
       else:
           new_node.input.append(full_input_name)  # 不是的話,維持原先的input
   nodes_after_modify.append(new_node)  # 將新節點存到list

new_gd = graph_pb2.GraphDef()  # 產生新graph def
new_gd.node.extend(nodes_after_modify)  # 在新graph def中生成那些新節點後return
return new_gd

執行 session 時,feed_dict 的 training 節點就可以不必再指定啦!

with tf.Session() as sess:
    input_node = tf.get_default_graph().get_tensor_by_name(
        "new_input_node:0")
    output_node = tf.get_default_graph().get_tensor_by_name(
        "final_dense/MatMul:0")

    image = cv2.imread('../05/ithome.jpg')
    image = cv2.resize(image, (128, 128))

    output = sess.run(output_node, feed_dict={input_node: np.expand_dims(image, 0)})
    print(output)

我們來看看 graph 的變化,這是執行前:
https://ithelp.ithome.com.tw/upload/images/20191003/20107299Y1lRdoxtKX.png

這是執行後的 bn:
https://ithelp.ithome.com.tw/upload/images/20191003/20107299SD05fGgzAc.png

和執行後的 dropout:
https://ithelp.ithome.com.tw/upload/images/20191003/20107299NcKJKKHlFZ.png

整個模型來看,training 節點不見了。
https://ithelp.ithome.com.tw/upload/images/20191003/20107299itMVALAUkL.png

輸出的話,也可以看到數字沒改變,模型的邏輯還是一樣的!
https://ithelp.ithome.com.tw/upload/images/20191003/201072995kHfDCJhS8.png

github原始碼


上一篇
【24】tensorflow 模型優化手術:把一般輸入改成 pre processing 輸入篇
下一篇
【26】tensorflow 模型優化手術:推論時經過 dropout 很多餘?如何去掉 dropout layer 篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言