(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
JAX while_loop() 運算子的運作方式,基本上沿襲了 Python while 敍述,藉由判斷一個條件式之真偽,來決定迴圈主體是否要繼續執行。
參數 cond_fun 為「條件判斷函式」,其回傳值 (True 或 False) 決定了是否要繼續執行一次 body_fun,其輸入參數:
參數 body_fun 為迴圈的主體函式,其輸入參數:
其回傳值將做為下一次迴圈的:
在實作 body_fun() 時要注意,其回傳值的維度 (shape) 和型態 (type) 必須和 init_val 一致,否則 while_loop 將會產生執行時錯誤。當然,如果 init_val 是容器式的資料結構,如 tuple/list/dict 等等,那麼 body_fun 的回傳值也要維持相同的資料結構。
參數 init_val 將做為第一圈執行時 cond_fun 及 body_fun 的輸入參數。
while_loop 的運作流程,可以用以下的 Python 程式段來說明:
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
while_loop 在執行時,JAX JIT 會自動編譯 cond_fun 和 body_fun,程式中沒有必要另外宣告它們為 @jax.jit 函式。
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
output:
*DeviceArray(10, dtype=int32, weak_type=True)
while_loop 是一個相對單純的運算子,老頭就介紹到這裏。