iT邦幫忙

2022 iThome 鐵人賽

DAY 22
0
AI & Data

JAX 好好玩系列 第 22

JAX 好好玩 (22) : 控制流程 (4) : while_loop

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

JAX while_loop() 運算子的運作方式,基本上沿襲了 Python while 敍述,藉由判斷一個條件式之真偽,來決定迴圈主體是否要繼續執行。

while_loop() 的語法及語義

https://ithelp.ithome.com.tw/upload/images/20220926/20129616IdPxcxNfMn.png

參數 cond_fun 為「條件判斷函式」,其回傳值 (True 或 False) 決定了是否要繼續執行一次 body_fun,其輸入參數:

  • 第一次呼叫時,為 init_val (while_loop 的第三個輸入參數) 。
  • 其後的呼叫,為前一次迴圈 body_fun 的回傳值。

參數 body_fun 為迴圈的主體函式,其輸入參數:

  • 第一次呼叫時,為 init_val (while_loop 的第三個輸入參數) 。
  • 其後的呼叫,為前一次迴圈 body_fun 的回傳值。

其回傳值將做為下一次迴圈的:

  • cond_fun() 輸入參數。
  • body_fun() 輸入參數。
  • 若下一次迴圈的 cond_fun() 為 False 時,其將做為整個 while_loop 的回傳值。

在實作 body_fun() 時要注意,其回傳值的維度 (shape) 和型態 (type) 必須和 init_val 一致,否則 while_loop 將會產生執行時錯誤。當然,如果 init_val 是容器式的資料結構,如 tuple/list/dict 等等,那麼 body_fun 的回傳值也要維持相同的資料結構。

參數 init_val 將做為第一圈執行時 cond_fun 及 body_fun 的輸入參數。

while_loop 的運作流程,可以用以下的 Python 程式段來說明:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

while_loop 在執行時,JAX JIT 會自動編譯 cond_fun 和 body_fun,程式中沒有必要另外宣告它們為 @jax.jit 函式。

while_loop() 的使用範例

init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)

output:
*DeviceArray(10, dtype=int32, weak_type=True)

while_loop 是一個相對單純的運算子,老頭就介紹到這裏。


上一篇
JAX 好好玩 (21) : 控制流程 (3) : cond
下一篇
JAX 好好玩 (23) : 控制流程 (5) : switch
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言