iT邦幫忙

2022 iThome 鐵人賽

DAY 21
0
AI & Data

JAX 好好玩系列 第 21

JAX 好好玩 (21) : 控制流程 (3) : cond

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

cond 是 condition 的縮寫,它可以說是 Python if 的 JAX 版本。

cond() 的語法及語義

https://ithelp.ithome.com.tw/upload/images/20220926/20129616NdSSxgomKz.png

參數 pred 它的型態是布林純量,當它為 True 時,true_fun() 將被執行並回傳值;反之,將會回傳 false_fun() 的執行結果。

參數 true_fun 和 false_fun 可以是一般函式 (function) 或是匿名函式 (lambda),它們必須接受相同的輸入參數列。

參數 operands 乃是 true_fun() 及 false_fun() 的輸入參數。

cond 的運作流程,可以用以下的 Python 程式段來說明:

def python_cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)

cond() 的使用範例

老頭之前的貼文曾經提到過一個造成執行時錯誤的例子:

@jit
def f(x):
    if x < 3:
    return 3. * x ** 2
else:
    return -4 * x

我們現在就用 cond 來改寫上面的程式,藉以了解 cond 的用法:

true_fun = lambda x : 3. * x **2
false_fun = lambda x : -4. * x

x = 2
print(lax.cond(x<3, true_fun, false_fun, x))
x = 5
print(lax.cond(x<3, true_fun, false_fun, x))

output:
12.0
-20.0

為什麼不宣告 @jax.jit ?

JAX 的「控制流程」是為了要解決 JIT 和 Python 既有的控制流程指令不相容的問題,但是在上面的例子裏,我們並沒有在 true_fun() 和 false_fun() 上面加上 @jax.jit 的修飾,也沒有用 jax.jit() API 去重新定義它們,這是老頭忘記了嗎?

不是的!當執行到 cond() API 的時候,JAX 會自動編譯 true_fun() 和 false_fun() ,在程式中不需要另外宣告。而不管 cond 被第一次呼叫的當時,pred 是 True 或者是 False, JAX 並不區分那一個函式會被第一次執行,true_fun() 和 false_fun() 會被一起編譯。從以下的程式片段,大家可以清楚的看到,對兩個函式的「追踪」都是發生在第一次呼叫的時候。
https://ithelp.ithome.com.tw/upload/images/20220926/20129616sQ3gATurBn.png


上一篇
JAX 好好玩 (20) : 控制流程 (2) : fori_loop
下一篇
JAX 好好玩 (22) : 控制流程 (4) : while_loop
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言