(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
cond 是 condition 的縮寫,它可以說是 Python if 的 JAX 版本。
參數 pred 它的型態是布林純量,當它為 True 時,true_fun() 將被執行並回傳值;反之,將會回傳 false_fun() 的執行結果。
參數 true_fun 和 false_fun 可以是一般函式 (function) 或是匿名函式 (lambda),它們必須接受相同的輸入參數列。
參數 operands 乃是 true_fun() 及 false_fun() 的輸入參數。
cond 的運作流程,可以用以下的 Python 程式段來說明:
def python_cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
老頭之前的貼文曾經提到過一個造成執行時錯誤的例子:
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
我們現在就用 cond 來改寫上面的程式,藉以了解 cond 的用法:
true_fun = lambda x : 3. * x **2
false_fun = lambda x : -4. * x
x = 2
print(lax.cond(x<3, true_fun, false_fun, x))
x = 5
print(lax.cond(x<3, true_fun, false_fun, x))
output:
12.0
-20.0
JAX 的「控制流程」是為了要解決 JIT 和 Python 既有的控制流程指令不相容的問題,但是在上面的例子裏,我們並沒有在 true_fun() 和 false_fun() 上面加上 @jax.jit 的修飾,也沒有用 jax.jit() API 去重新定義它們,這是老頭忘記了嗎?
不是的!當執行到 cond() API 的時候,JAX 會自動編譯 true_fun() 和 false_fun() ,在程式中不需要另外宣告。而不管 cond 被第一次呼叫的當時,pred 是 True 或者是 False, JAX 並不區分那一個函式會被第一次執行,true_fun() 和 false_fun() 會被一起編譯。從以下的程式片段,大家可以清楚的看到,對兩個函式的「追踪」都是發生在第一次呼叫的時候。