iT邦幫忙

2022 iThome 鐵人賽

DAY 16
0
AI & Data

JAX 好好玩系列 第 16

JAX 好好玩 (16) : JAX JIT (5) : 如何不被追踪

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

上一次的貼文提到以上這個例子,它在執行時會報錯。

@jax.jit
def f2(x, neg):
  return -x if neg else x # will raise run-time error
 
f2(1, True)

理由是在 JIT 函式中,程式的控制邏輯不能取決於「追踪物件」的值。然而,如果在我們的程式中必須有這樣的設計時,應該怎麼辦呢 ? 一個解決方法是,讓這個輸入參數「不要被追踪」。我們可以利用 Python 的偏函式 (partial) ,將 neg 宣告為「靜態輸入參數」:

from functools import partial
 
@partial(jax.jit, static_argnums=(1,))
def f3(x, neg):
  return -x if neg else x
 
x=1
f3(1, True)

output:
DeviceArray(-1, dtype=int32)

這樣,執行時就不會報錯了。更進一步來看靜態輸入參數,我們把 f3() 改寫成 f4(),刻意印出它的兩個參數:

@partial(jax.jit, static_argnums=(1,))
def f4(x, neg):
    print(f'traced argument x = {x}')
    print(f'static argument neg = {neg}')
    return -x if neg else x
 
x=1
f4(x, True)

output:
*traced argument x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
static argument neg = True
DeviceArray(-1, dtype=int32, weak_type=True

此時,輸入參數 neg 已經不再是追踪物件了,而是一個靜態常數 True!這個是 Python 偏函式的宣告所造成的效果。

因為靜態輸入參數在追踪時,會被視為常數,因此,若 JIT 函式下回被以「不同值」的靜態輸入參數呼叫時,JAX 就必須再追踪它一次。追踪是要花時間的,反覆的被追踪,那麼此一函式的執行效率將大打折扣。(有一好,必有一壞) 如下例,當我們用 False 這個值再呼叫 f4() 時,print() 的輸出仍然有作用,所以我們知道,又作了一次追踪。

f4(x,False)

output:
traced argument x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
static argument neg = False
DeviceArray(1, dtype=int32, weak_type=True)

這種效率上的打折,是我們在使用靜態輸入參數時要考量的,通常我們可以先作些測試,評估使用 JIT 和沒有使用的差別後,再決定程式要怎麼寫。


上一篇
JAX 好好玩 (15) : JAX JIT (4) : 追踪 (Tracing)
下一篇
JAX 好好玩 (17) : JAX JIT (6) : 解密全域變數的怪異行為
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言