iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 26
0
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 26

【26】tensorflow 模型優化手術:推論時經過 dropout 很多餘?如何去掉 dropout layer 篇

  • 分享至 

  • xImage
  •  

介紹完預處理後,這次要來介紹的是 dropout 的去除,dropout 是訓練過程中,一種會隨機使輸出變成 0 的 layer,但如果我們目的只有推論的話,那其實可以把 dropout 完全拔掉。

我們先來看看目前的 dropout layer:
https://ithelp.ithome.com.tw/upload/images/20191004/2010729954Emjzh1Bf.png

然後在 opt 區塊多加 remove_dropout()

# opt start #

preprocess_gd = add_preprocessing('backend/conv_1/Conv2D', 'input_node')
update_graph(preprocess_gd)

no_training_gd = remove_training_node('training')
update_graph(no_training_gd)

no_dropout_gd = remove_dropout()
update_graph(no_dropout_gd)

# opt end #

由最上面的 graph 圖可知,如果要去除 dropout 我們要把上面的 MaxPool 跟下面的 relu_2 連起來,才算是將 dropout 脫離,所以我們會需要三個參數:被脫離的 dropout、脫離前和脫離後的節點:

def remove_dropout():
   old_gd = tf.get_default_graph().as_graph_def()
   new_gd = strip_dropout(old_gd,
                          drop_scope='backend/dropout_2',
                          dropout_before='backend/relu_2',
                          dropout_after='backend/max_pool_2/MaxPool')

   return new_gd

我們將脫離的方法獨立出來寫成 strip_dropout(),邏輯大致如下:先跑過每個node,如果是dropout scope 底下的就忽略。而如果遇到 dropout 的輸出節點,就檢查該節點的 input 是不是有 dropout_scope 的東西,有的話就把它改成 dropout 的輸入節點,否則保持原樣。

input_nodes = input_graph.node
nodes_after_strip = []
for node in input_nodes:  # for所有節點
   if node.name.startswith(drop_scope + '/'):  # drop_scope底下跳過
       continue

   new_node = node_def_pb2.NodeDef()  # 產生新節點
   new_node.CopyFrom(node)  # 拷貝舊節點資訊到新節點
   if new_node.name == dropout_after:  # 若該節點是after
       new_input = []
       for node_name in new_node.input:  # 檢查該節點的input
           if node.name.startswith(drop_scope + '/'):
               new_input.append(dropout_before)  # 若是drop_scope改成前節點
           else:
               new_input.append(node_name)  # 若不是,保持原input
       del new_node.input[:]
       new_node.input.extend(new_input)  # 更新input
   nodes_after_strip.append(new_node)  # 將新節點存到list

new_gd = graph_pb2.GraphDef()
new_gd.node.extend(nodes_after_strip)
return new_gd

接下來我們執行它來看看新生成的 graph:
https://ithelp.ithome.com.tw/upload/images/20191004/20107299ezfudZXfvl.png

很好,dropout 不見了,那再來檢查輸出:
https://ithelp.ithome.com.tw/upload/images/20191004/20107299Yz8zDpvZkN.png

數值相同,這個模型我給過!

github原始碼


上一篇
【25】tensorflow 模型優化手術:每次推論都要對 training node 設 false 很麻煩,如何鎖住 training node 篇
下一篇
【27】tensorflow 模型優化手術:除去冗贅的 Merge、Switch 運算節點,模型再剪枝篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言