(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
上一次的貼文提到以上這個例子,它在執行時會報錯。
@jax.jit
def f2(x, neg):
return -x if neg else x # will raise run-time error
f2(1, True)
理由是在 JIT 函式中,程式的控制邏輯不能取決於「追踪物件」的值。然而,如果在我們的程式中必須有這樣的設計時,應該怎麼辦呢 ? 一個解決方法是,讓這個輸入參數「不要被追踪」。我們可以利用 Python 的偏函式 (partial) ,將 neg 宣告為「靜態輸入參數」:
from functools import partial
@partial(jax.jit, static_argnums=(1,))
def f3(x, neg):
return -x if neg else x
x=1
f3(1, True)
output:
DeviceArray(-1, dtype=int32)
這樣,執行時就不會報錯了。更進一步來看靜態輸入參數,我們把 f3() 改寫成 f4(),刻意印出它的兩個參數:
@partial(jax.jit, static_argnums=(1,))
def f4(x, neg):
print(f'traced argument x = {x}')
print(f'static argument neg = {neg}')
return -x if neg else x
x=1
f4(x, True)
output:
*traced argument x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
static argument neg = True
DeviceArray(-1, dtype=int32, weak_type=True
此時,輸入參數 neg 已經不再是追踪物件了,而是一個靜態常數 True!這個是 Python 偏函式的宣告所造成的效果。
因為靜態輸入參數在追踪時,會被視為常數,因此,若 JIT 函式下回被以「不同值」的靜態輸入參數呼叫時,JAX 就必須再追踪它一次。追踪是要花時間的,反覆的被追踪,那麼此一函式的執行效率將大打折扣。(有一好,必有一壞) 如下例,當我們用 False 這個值再呼叫 f4() 時,print() 的輸出仍然有作用,所以我們知道,又作了一次追踪。
f4(x,False)
output:
traced argument x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
static argument neg = False
DeviceArray(1, dtype=int32, weak_type=True)
這種效率上的打折,是我們在使用靜態輸入參數時要考量的,通常我們可以先作些測試,評估使用 JIT 和沒有使用的差別後,再決定程式要怎麼寫。