(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
Python 本身並沒有什麼問題,然而搭配 JAX JIT 時,卻有一些行為是超出一般的想像,由其是與「控制流程 (control flow)」相關的程式段。這一回,我們就來看看兩個常見的問題。第一個是「執行時錯誤」,第二個是「迴圈展開」。
在 JIT 函式裏使用 Python 的控制流程指令 if / for / while 是要非常小心的,如果沒有遵循 JIT 函式的注意事項,有時會導致執行錯誤。例如以下的例子[19.1]:
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
try:
f(2)
except Exception as e:
print("Exception {}".format(e))
output:
Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with thebool
function.
The error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
它違反了原則
另外一個違反同一個原則的例子 [19.1]:
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jax.jit(f)
f(jnp.array([2., 3., 4.]), 3)
output:
UnfilteredStackTrace Traceback (most recent call last)
in
8
----> 9 f(jnp.array([2., 3., 4.]), 3)
....
....
TracerIntegerConversionError: The index() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
它的錯誤是「不該在控制流程 for 中參考到輸入參數 n 的值」(另一個說法是,函式內的執行流程,不能被輸入參數的值所影響)。這個問題的其中一個解決方法,是將 n 宣告為靜態輸入變數,使得 JAX 不再將其視為追踪物件,這樣,函式就可以正確執行了:
def f2(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f2 = jax.jit(f2, static_argnums=(1,))
f2(jnp.array([2., 3., 4.]), 3)
output:
DeviceArray(9., dtype=float32)
於 JIT 函式裏使用 Python 的 for 迴圈,會產生「迴圈展開 (loop unrolling)」的現象。老頭以下面的例子來解釋這一現象:
def f3(x):
y = 0.
for i in range(5):
y = y + x[i]
return y
f3 = jax.jit(f3)
f3(jnp.array([1., 2., 3., 4., 5.]))
output:
DeviceArray(15., dtype=float32)
函式 f3() 並沒有違反先前提到的原則,執行起來的結果也正確,表面上沒什麼問題,但是如果我們把 f3() 的 JAX 展開式列出來,你就會發現那裏不對勁了!
jax.make_jaxpr(f3)(jnp.array([1.,2., 3., 4.,5.]))
從上圖的註記,大家可以看到 y = y + x[i] 這一行指令被重覆了 5 次 !這種將迴圈內的程式片斷展開成平鋪直述程式段,即是「迴圈展開」。設想,如果我們有一個重覆 100 萬次的迴圈,那麼所有產生的執行碼會多大?這絕對是不能被接受的。為什麼會這樣呢?因為 Python (及其他高階程式語言) 在設計時,主要是以在 CPU 上執行為其考量點,並沒有考慮到這些控制指令是否能在 GPU/TPU 等硬體加速器上執行,JAX 在追踪這一類的 Python 指令時,並無法找到對應的 JAX 表示式,只能將其展開,將迴圈轉變為一行行循序執行的指令序列。
為了解決這些問題,JAX 提出了自己控制流程運算子 (control flow operators),其中重要的是:
接下來,老頭將一一介紹這些運算子。
註:
[19.1] 這些例子來自於 JAX 文件管網 python control flow + JIT