不知道大家有沒有遇到一種情況,就是你用 jupyter notebook 或其他 IDE 分段執行程式碼時,run 節點架構時,tensorflow 噴出錯誤告訴你節點錯誤,這個就是新手常犯的錯誤:在同個 graph 重覆宣告啦!
首先,我們必須要知道tensorflow運作原理,tensorflow的運作原理就像是在堆骨牌一樣,你必須先將整個結構定義好,接著才能丟給session去執行,因此,很多新手就會因為重覆執行程式碼,導致 tensorflow 在同個 graph 上畫出一樣的東西而導致錯誤。
(如上圖,當你在 jupyter notebook 對第二個 cell 按了好幾次執行,其實也在 graph 裡產生很多個 c0常數)
Okay 那我們該如何管理這些graph呢?
第一點,tensoflow 有預設的 graph,如果你沒有特別指定,那麼宣告的 tensor 節點都是在預設的 graph 上,如果需要這張預設的 graph,你可以呼叫
tf.get_default_graph()
來取得。
第二點,如果你新開一張 graph 使用,你可以使用
with tf.Graph().as_default()
如此,你以下所宣告的節點就會在此 graph 上。
這邊為大家示範,在預設 graph 和另外兩個新的 graph 使用的程式碼範例。
tf.constant(0, name="c0")
with tf.Graph().as_default() as g1:
tf.constant(1, name="c1")
with tf.Graph().as_default() as g2:
tf.constant(2, name="c2")
我在預設的 graph 宣告 c0這個常數,在 g1 宣告 c1,g2 宣告 c2,最後個別產生三個不同的 session 在各自不同的 graph 上執行。
sess0 = tf.Session()
sess1 = tf.Session(graph=g1)
sess2 = tf.Session(graph=g2)
t0 = tf.get_default_graph().get_tensor_by_name('c0:0')
t1 = sess1.graph.get_tensor_by_name('c1:0')
t2 = sess2.graph.get_tensor_by_name('c2:0')
result0 = sess0.run(t0)
result1 = sess1.run(t1)
result2 = sess2.run(t2)
print("result 0: {}".format(result0))
print("result 1: {}".format(result1))
print("result 2: {}".format(result2))
print("default graph: {}".format([n.name for n in tf.get_default_graph().as_graph_def().node]))
print("graph 1: {}".format([n.name for n in g1.as_graph_def().node]))
print("graph 2: {}".format([n.name for n in g2.as_graph_def().node]))
sess0.close()
sess1.close()
sess2.close()
結果印出。
以上就是 tensorflow 多個 graph 的使用方式。