iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 27
0
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 27

【27】tensorflow 模型優化手術:除去冗贅的 Merge、Switch 運算節點,模型再剪枝篇

  • 分享至 

  • xImage
  •  

既然 dropout 已經被去除了,接下來我們把目標瞄準到 batch normalization 優化吧,那該優化什麼呢?我們先來看看 batch normalization 在 tensorboard 上的狀態吧。
https://ithelp.ithome.com.tw/upload/images/20191005/20107299Tn658XXnw9.png

第一點:
有兩個 FusedBatchNorm ,在 day10 介紹了 batch normalization 在訓練和推論兩個模式下會有不同的行為,其中這次我們感興趣的是 FusedBatchNorm_1 這個節點,因為它含有 moving_mean 和 moving_var 是堆論模式下會所會用到的參數。
https://ithelp.ithome.com.tw/upload/images/20191005/20107299hcLoeLKLso.png

而從第一張圖來看,這兩個 FusedBatchNorm 都會指向到 merge 這個節點,所以我們優化的第一個重點就是只留 FusedBatchNorm_1 這條路啦!

那我們來撰寫 remove_merge 這個方法,這邊有個重點就是 Merge 這個節點會有兩個 input ,input[0] 的要保留,input[1] 不保留。

def remove_merge():
   old_gd = tf.get_default_graph().as_graph_def()
   old_nodes = old_gd.node

   names_to_remove = {}  # 產生一個dict記錄等等哪些節點要被砍掉
   for node in old_nodes:  # for所有節點
       if node.op == 'Merge':  # 當發現節點是Merge時
           # FusedBatchNorm_1=0, FusedBatchNorm=1
           # 將節點存到key,位置0存到value
           names_to_remove[node.name] = node.input[0]
           # 位置1因為要被砍,所以只存key
           names_to_remove[node.input[1]] = None
   print(f'remove merge: {names_to_remove}')

為什麼只保留 input[0] 而不保留 input[1] 呢?可以看到下圖中右上方資訊,第一個 input 是 FusedBatchNorm_1,位於 input[0],而我們不要的 FusedBatchNorm 是第二個輸入位於 input[1],所以 names_to_remove 就塞個 None 給它。
https://ithelp.ithome.com.tw/upload/images/20191005/20107299u2EiTsXTc0.png

而 remove_merge 的後半段,跟前幾天的邏輯差不多,不一樣的地方在於多了要從 names_to_remove 找上個輸出這件事情,所以會多了 while 迴圈。

nodes_after_modify = []
for node in old_nodes:  # for所有節點
   if node.name in names_to_remove:  # 如果是要被砍掉的就跳過
       continue
   new_node = node_def_pb2.NodeDef()
   new_node.CopyFrom(node)
   input_before_removal = node.input
   del new_node.input[:]
   for full_input_name in input_before_removal:  # 檢查input節點
       # 去除名稱中其他符號
       input_name = re.sub(r"^\^", "", full_input_name)
       # 如果input_name是要被砍的,就進迴圈
       while input_name in names_to_remove:
           # 找上個輸入
           full_input_name = names_to_remove[input_name]
           # 去除名稱中其他符號
           input_name = re.sub(r"^\^", "", full_input_name)
         # 把FusedBatchNorm_1加到新節點
       new_node.input.append(full_input_name)
   nodes_after_modify.append(new_node)  # 新增該節點

new_gd = graph_pb2.GraphDef()
new_gd.node.extend(nodes_after_modify)
return new_gd

去除 merge 後的 graph,可以看到只留下 FusedBatchNorm_1 囉:
https://ithelp.ithome.com.tw/upload/images/20191005/20107299NIp9CU6pfN.png

但各位有沒有注意到,圖中仍有個 false_node 指向 pred_id 節點再連接到 switch,這個東西就是用來控制要不要拿 gamma, beta, moving_mean, mean_var 的控制器,既然我們現在只要推論,那是不是連 switch 都可以拿掉呢?

當然可以,我們來撰寫 remove_switch() 的前半部,這次目標是找 switch,並且將這個 switch 的 input 檢查一遍,把 pred_id 以外的 input 存到 names_to_remove 這個 dict 中。

def remove_switch():
   old_gd = tf.get_default_graph().as_graph_def()
   old_nodes = old_gd.node

   names_to_remove = {}  # 產生一個dict記錄哪些節點要被砍掉
   for node in old_nodes:  # for所有節點
       if node.op != 'Switch':  # 當發現節點不是Switch時跳過
           continue
       for node_i in node.input:  # 檢查Switch的input
           # 若input是pred_id,把pred_id以外的input存到dict
           if node_i.split('/')[-1] == 'pred_id':
               names_to_remove[node.name] = [
                   x for x in node.input if x.split('/')[-1] != 'pred_id'

後半段的部分也和之前差不多,即跑一次所有節點,然後檢查每個節點的 input 是不是 switch,如果是就把上半部 dict 中的 input 補回去,程式碼如下:

nodes_after_modify = []
for node in old_nodes:  # for所有節點
   if node.name in names_to_remove:
       continue  # 如果是要被砍掉的就跳過
   new_node = node_def_pb2.NodeDef()  # 產生新節點
   new_node.CopyFrom(node)  # 拷貝舊節點資訊
   input_before_removal = node.input
   del new_node.input[:]  # 去掉所有input
   for full_input_name in input_before_removal:
       if full_input_name in names_to_remove:  # 如果此節點input是switch
           for input_name in names_to_remove[full_input_name]:
               # 就把names_to_remove的所有input補上
               new_node.input.append(input_name)
       else:  # 不然一切照舊
           new_node.input.append(full_input_name)
   nodes_after_modify.append(new_node)

new_gd = graph_pb2.GraphDef()
new_gd.node.extend(nodes_after_modify)
return new_gd

再來看看執行後的成果吧:
https://ithelp.ithome.com.tw/upload/images/20191005/201072991xBl919ioa.png

整個 FusedBatchNorm_1 變成一個節點,輸入直接連上 gamma, beta, moving_mean, mean_var 沒經過 switch,一看就很清爽!

再來是驗證數值的時候啦!
https://ithelp.ithome.com.tw/upload/images/20191005/20107299FcEBJUqpqb.png

數值一樣,PASS !

github原始碼


上一篇
【26】tensorflow 模型優化手術:推論時經過 dropout 很多餘?如何去掉 dropout layer 篇
下一篇
【28】tensorflow 模型優化手術:給我折下去!模型 folding batch normalization 篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言