(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
從這個運算子 (operator) 的名字就可以看出來,它和 Python 中的 for 敍述有密切的關係。是的,它就是 for 的 JAX 版本!這一回老頭會先介紹 fori_loop() 的語法及語義,舉幾個例子,而後,再來看看它是否解決了 Python 迴圈展開的問題。
首先要注意的,JAX 控制流程運算子都是定義在 jax.lax 封裝 (package) 裏。事實上,所有的 JAX 流程控制都是定義在這個封裝之下。
low 和 upper 這兩個整數參數的用法和 Python 的 for 一樣,界定的迴圈的下限 (包含) 及上限 (不包含),而目前 fori_loop 只提供步幅 (step) 為 1 的運算。
body_fun 是 fori_loop 迴圈執行時的主體,它可以是一般函式 (function) 或是匿名函式 (lambda),它接受兩個參數:
在實作 body_fun 時要注意,它的回傳值必須保持和 init_val 相同的維度 (shape) 和型態 (type),如果 init_val 是容器式的資料結構,如 tuple/list/dict 等等,那麼 body_fun 的回傳值也要維持相同的資料結構。簡單講,body_fun 的回傳值要保持和 init_val 相同的資料結構、型態、維度。
fori_loop 的運作流程 (也就是其語義) ,可以用以下的 Python 程式段來說明:
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
1. 以匿名函式作為 body_fun():
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
output:
*DeviceArray(45, dtype=int32, weak_type=True)
2. 改寫之前「迴圈展開」的例子 (計算 [1.,2.,3.,4.,5.] 內所有元素的加總)
def no_unrolling(x, num):
init_val = 0.
start = 0
stop = num
def body_fun(i, carry):
return carry + x[i]
return lax.fori_loop(start, stop, body_fun, init_val)
no_unrolling(jnp.array([1.,2.,3.,4.,5.]), 5)
output:
*DeviceArray(15., dtype=float32)
讀者們可以將這段程式和上一回的「迴圈展開」例子做個對照,未來如果有需要將 Python for 程式改寫為 JAX fori_loop 時,可以參考這二個對照的例子,就知道該怎麼做了。
我們先印出含有 5 個迴圈的 JAX 表示式。
再印出含有 10000 個迴圈的表示式。
儘管目前大家不見得能夠完全了解 JAX 表示式,但是將這兩個結果對照起來,它們的表示式長度並沒有明顯的差異。我們應該能夠確定 fori_loop() 並不會發生迴圈展開的現象。