iT邦幫忙

2022 iThome 鐵人賽

DAY 17
0
AI & Data

JAX 好好玩系列 第 17

JAX 好好玩 (17) : JAX JIT (6) : 解密全域變數的怪異行為

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

老頭在之前「純函式 (Pure Function)」的貼文中,曾經舉了兩個例子,說明在 JIT 函式內使用全域變數所產生的怪異現象,現在我們就來看看這些怪異行為背後的原因。

回顧第一個例子:

g1 = 0.
def impure_uses_globals(x):
  return x + g1

impure_jit_fun = jax.jit(impure_uses_globals)

# 1st call, at this moment g1 = 0
print(f'1st call: {impure_jit_fun(99.9)}')
print('===========================================')

# not change g1 to 1.
g1 = 1.
print(f'Currently g1 = {g1}')
print('===========================================')

# 2nd call, at this moment g1 = 0
print(f'2nd call: {impure_jit_fun(99.9)}')

output:
1st call: 99.9000015258789
'==========================================='
Currently g1 = 1.0
'==========================================='
2nd call: 99.9000015258789

不管我們怎麼改變全域變數 g1 的值,函式 impure_uses_globals() 「似乎」永遠都把 g1 視為 0. !為什麼?關鍵在於 JIT 的「追踪」上。

我們改寫一下前面的程式片斷,在第一次呼叫 impure_uses_globals() 之前,把 JAX 表示式印出來,看看 JAX 是怎麼追踪它的:

g1 = 0.
 
@jax.jit
def impure_uses_globals(x):
    print(f'x = {x}')
    print(f'g1 = {g1}')
    print(f'==========================')
    return x + g1
 
jax.make_jaxpr(impure_uses_globals)(99.9)

output:
https://ithelp.ithome.com.tw/upload/images/20220919/20129616qlCSJxRYpx.png

對於輸入參數 x , JAX 把它設為一個追踪物件 (tracer object),其維度為 0 (純量) 型態為 float32。而 g1 則被追踪為一個常數 0.0,這個就是關鍵!在 JAX 表示式中,g1 不再是變數,而是常數 0.0,因此,在第二次呼叫函式的時候,儘管當時全域變數 g1 的值已經被改掉了,但是 JAX 表示式中,它還是 0.0。
https://ithelp.ithome.com.tw/upload/images/20220919/20129616OndSHEB5Lg.png

再來看看之前的第二個例子:

g2 = 0.
def impure_saves_global(x):
  global g2
  g2 = x
  return x

impure_jit_fun = jax.jit(impure_saves_global)

print(f'Before 1st call, g2 = {g2}')
print('===========================================')
print(f'1st call: {impure_jit_fun(1.0)}')
print('===========================================')
print(f'After 1st call, g2 = {g2}')

output:
Before 1st call, g2 = 0.0
'==========================================='
1st call: 1.0
'==========================================='
After 1st call, g2 = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

把它重寫之後,我們就可以很容易看出到底發生了什麼事:

g2 = 0.
def impure_saves_global(x):
    global g2
    print(f'input parameter x = {x}')
    print(f'Befoer assignment, g2 = {g2}')
    g2 = x
    print(f'After  assignment, g2 = {g2}')
    print('===========================================')
    return x
 
impure_jit_fun = jax.jit(impure_saves_global)
 
print(f'1st call: {impure_jit_fun(1.0)}')

output:
input parameter x = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Befoer assignment, g2 = 0.0
After assignment, g2 = Traced<ShapedArray(float32[],
weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
'==========================================='
1st call: 1.0

一目了然了吧!因為輸入參數 x 在追踪時會成為一個追踪物件 (tracer object) ,在函式裏把這個追踪物件指定給全域變數 g2,g2 也就成為一個追踪物件了。很瞎,追踪物件在 JIT 函式外幾乎沒有什麼用途,這個結果當然不是程式設計者的原始想法。

總而言之,不要在 JIT 函式內使用全域變數,除非你真的知道使用它會造成的結果。

最後再舉一個於 JAX 官方文件中的例子作為此次貼文的結束 [17.1]:

global_list = []
 
def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2
 
print(jax.make_jaxpr(log2)(3.0))

output:
https://ithelp.ithome.com.tw/upload/images/20220919/20129616cRhZQCOP9b.png

https://ithelp.ithome.com.tw/upload/images/20220919/20129616lX64OAZFMY.png

我們仔細分析一下 JIT 函式 log2() 的 JAX 表示式,可以看到在此表示式中,跟本沒有包括和全域變數 global_list 相關的運算,也就是說,JAX 在追踪函式時,它不會把那些造成函式不純的副作用 (如 print, 全域變數等) 放入 JAX 表示式裏,這些副作用也只有在函式被追踪時執行一次。

要提醒大家的是,這種副作用被執行一次的行為,只是目前 JAX 實作時的一種選擇,JAX 並不排除未來會改變這樣的設計,改成一次都不執行也是有可能的,所以我們的 JAX 程式運行正確與否,絕對不要依賴這些副作用。

[17.1] 這一個例子來自 JAX 文件官網 How JAX transforms work 一文的內容。


上一篇
JAX 好好玩 (16) : JAX JIT (5) : 如何不被追踪
下一篇
JAX 好好玩 (18) : JAX JIT (7) : 總結與回顧
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言