介紹完預處理後,這次要來介紹的是 dropout 的去除,dropout 是訓練過程中,一種會隨機使輸出變成 0 的 layer,但如果我們目的只有推論的話,那其實可以把 dropout 完全拔掉。
我們先來看看目前的 dropout layer:
然後在 opt 區塊多加 remove_dropout()
# opt start #
preprocess_gd = add_preprocessing('backend/conv_1/Conv2D', 'input_node')
update_graph(preprocess_gd)
no_training_gd = remove_training_node('training')
update_graph(no_training_gd)
no_dropout_gd = remove_dropout()
update_graph(no_dropout_gd)
# opt end #
由最上面的 graph 圖可知,如果要去除 dropout 我們要把上面的 MaxPool 跟下面的 relu_2 連起來,才算是將 dropout 脫離,所以我們會需要三個參數:被脫離的 dropout、脫離前和脫離後的節點:
def remove_dropout():
old_gd = tf.get_default_graph().as_graph_def()
new_gd = strip_dropout(old_gd,
drop_scope='backend/dropout_2',
dropout_before='backend/relu_2',
dropout_after='backend/max_pool_2/MaxPool')
return new_gd
我們將脫離的方法獨立出來寫成 strip_dropout(),邏輯大致如下:先跑過每個node,如果是dropout scope 底下的就忽略。而如果遇到 dropout 的輸出節點,就檢查該節點的 input 是不是有 dropout_scope 的東西,有的話就把它改成 dropout 的輸入節點,否則保持原樣。
input_nodes = input_graph.node
nodes_after_strip = []
for node in input_nodes: # for所有節點
if node.name.startswith(drop_scope + '/'): # drop_scope底下跳過
continue
new_node = node_def_pb2.NodeDef() # 產生新節點
new_node.CopyFrom(node) # 拷貝舊節點資訊到新節點
if new_node.name == dropout_after: # 若該節點是after
new_input = []
for node_name in new_node.input: # 檢查該節點的input
if node.name.startswith(drop_scope + '/'):
new_input.append(dropout_before) # 若是drop_scope改成前節點
else:
new_input.append(node_name) # 若不是,保持原input
del new_node.input[:]
new_node.input.extend(new_input) # 更新input
nodes_after_strip.append(new_node) # 將新節點存到list
new_gd = graph_pb2.GraphDef()
new_gd.node.extend(nodes_after_strip)
return new_gd
接下來我們執行它來看看新生成的 graph:
很好,dropout 不見了,那再來檢查輸出:
數值相同,這個模型我給過!