(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
老頭在之前「純函式 (Pure Function)」的貼文中,曾經舉了兩個例子,說明在 JIT 函式內使用全域變數所產生的怪異現象,現在我們就來看看這些怪異行為背後的原因。
回顧第一個例子:
g1 = 0.
def impure_uses_globals(x):
return x + g1
impure_jit_fun = jax.jit(impure_uses_globals)
# 1st call, at this moment g1 = 0
print(f'1st call: {impure_jit_fun(99.9)}')
print('===========================================')
# not change g1 to 1.
g1 = 1.
print(f'Currently g1 = {g1}')
print('===========================================')
# 2nd call, at this moment g1 = 0
print(f'2nd call: {impure_jit_fun(99.9)}')
output:
1st call: 99.9000015258789
'==========================================='
Currently g1 = 1.0
'==========================================='
2nd call: 99.9000015258789
不管我們怎麼改變全域變數 g1 的值,函式 impure_uses_globals() 「似乎」永遠都把 g1 視為 0. !為什麼?關鍵在於 JIT 的「追踪」上。
我們改寫一下前面的程式片斷,在第一次呼叫 impure_uses_globals() 之前,把 JAX 表示式印出來,看看 JAX 是怎麼追踪它的:
g1 = 0.
@jax.jit
def impure_uses_globals(x):
print(f'x = {x}')
print(f'g1 = {g1}')
print(f'==========================')
return x + g1
jax.make_jaxpr(impure_uses_globals)(99.9)
output:
對於輸入參數 x , JAX 把它設為一個追踪物件 (tracer object),其維度為 0 (純量) 型態為 float32。而 g1 則被追踪為一個常數 0.0,這個就是關鍵!在 JAX 表示式中,g1 不再是變數,而是常數 0.0,因此,在第二次呼叫函式的時候,儘管當時全域變數 g1 的值已經被改掉了,但是 JAX 表示式中,它還是 0.0。
再來看看之前的第二個例子:
g2 = 0.
def impure_saves_global(x):
global g2
g2 = x
return x
impure_jit_fun = jax.jit(impure_saves_global)
print(f'Before 1st call, g2 = {g2}')
print('===========================================')
print(f'1st call: {impure_jit_fun(1.0)}')
print('===========================================')
print(f'After 1st call, g2 = {g2}')
output:
Before 1st call, g2 = 0.0
'==========================================='
1st call: 1.0
'==========================================='
After 1st call, g2 = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
把它重寫之後,我們就可以很容易看出到底發生了什麼事:
g2 = 0.
def impure_saves_global(x):
global g2
print(f'input parameter x = {x}')
print(f'Befoer assignment, g2 = {g2}')
g2 = x
print(f'After assignment, g2 = {g2}')
print('===========================================')
return x
impure_jit_fun = jax.jit(impure_saves_global)
print(f'1st call: {impure_jit_fun(1.0)}')
output:
input parameter x = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Befoer assignment, g2 = 0.0
After assignment, g2 = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
'==========================================='
1st call: 1.0
一目了然了吧!因為輸入參數 x 在追踪時會成為一個追踪物件 (tracer object) ,在函式裏把這個追踪物件指定給全域變數 g2,g2 也就成為一個追踪物件了。很瞎,追踪物件在 JIT 函式外幾乎沒有什麼用途,這個結果當然不是程式設計者的原始想法。
總而言之,不要在 JIT 函式內使用全域變數,除非你真的知道使用它會造成的結果。
最後再舉一個於 JAX 官方文件中的例子作為此次貼文的結束 [17.1]:
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
output:
我們仔細分析一下 JIT 函式 log2() 的 JAX 表示式,可以看到在此表示式中,跟本沒有包括和全域變數 global_list 相關的運算,也就是說,JAX 在追踪函式時,它不會把那些造成函式不純的副作用 (如 print, 全域變數等) 放入 JAX 表示式裏,這些副作用也只有在函式被追踪時執行一次。
要提醒大家的是,這種副作用被執行一次的行為,只是目前 JAX 實作時的一種選擇,JAX 並不排除未來會改變這樣的設計,改成一次都不執行也是有可能的,所以我們的 JAX 程式運行正確與否,絕對不要依賴這些副作用。
[17.1] 這一個例子來自 JAX 文件官網 How JAX transforms work 一文的內容。