iT邦幫忙

2022 iThome 鐵人賽

DAY 20
0
AI & Data

JAX 好好玩系列 第 20

JAX 好好玩 (20) : 控制流程 (2) : fori_loop

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

從這個運算子 (operator) 的名字就可以看出來,它和 Python 中的 for 敍述有密切的關係。是的,它就是 for 的 JAX 版本!這一回老頭會先介紹 fori_loop() 的語法及語義,舉幾個例子,而後,再來看看它是否解決了 Python 迴圈展開的問題。

fori_loop() 的語法及語義

https://ithelp.ithome.com.tw/upload/images/20220926/20129616VioPYYVl8T.png

首先要注意的,JAX 控制流程運算子都是定義在 jax.lax 封裝 (package) 裏。事實上,所有的 JAX 流程控制都是定義在這個封裝之下。

lowupper 這兩個整數參數的用法和 Python 的 for 一樣,界定的迴圈的下限 (包含) 及上限 (不包含),而目前 fori_loop 只提供步幅 (step) 為 1 的運算。

body_fun 是 fori_loop 迴圈執行時的主體,它可以是一般函式 (function) 或是匿名函式 (lambda),它接受兩個參數:

  • 第一個是迴圈的索引值 (整數,從 lowerupper - 1) 。
  • 第二個參數 (被稱為 carry value)
    -- 第一圈時,為 init_val 。
    -- 其後為前一圈的回傳值。

在實作 body_fun 時要注意,它的回傳值必須保持和 init_val 相同的維度 (shape) 和型態 (type),如果 init_val 是容器式的資料結構,如 tuple/list/dict 等等,那麼 body_fun 的回傳值也要維持相同的資料結構。簡單講,body_fun 的回傳值要保持和 init_val 相同的資料結構、型態、維度

fori_loop 的運作流程 (也就是其語義) ,可以用以下的 Python 程式段來說明:

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

一些例子

1. 以匿名函式作為 body_fun():

init_val = 0
start = 0
stop = 10

body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)

output:
*DeviceArray(45, dtype=int32, weak_type=True)

2. 改寫之前「迴圈展開」的例子 (計算 [1.,2.,3.,4.,5.] 內所有元素的加總)

def no_unrolling(x, num):

    init_val = 0.
    start = 0
    stop = num
    
    def body_fun(i, carry):
      return carry + x[i]

    return lax.fori_loop(start, stop, body_fun, init_val)

no_unrolling(jnp.array([1.,2.,3.,4.,5.]), 5)

output:
*DeviceArray(15., dtype=float32)

讀者們可以將這段程式和上一回的「迴圈展開」例子做個對照,未來如果有需要將 Python for 程式改寫為 JAX fori_loop 時,可以參考這二個對照的例子,就知道該怎麼做了。

檢查有沒有迴圈展開

我們先印出含有 5 個迴圈的 JAX 表示式。
https://ithelp.ithome.com.tw/upload/images/20220926/20129616qGC1uhwZ7D.png

再印出含有 10000 個迴圈的表示式。
https://ithelp.ithome.com.tw/upload/images/20220926/201296168l4hPuWOKF.png

儘管目前大家不見得能夠完全了解 JAX 表示式,但是將這兩個結果對照起來,它們的表示式長度並沒有明顯的差異。我們應該能夠確定 fori_loop() 並不會發生迴圈展開的現象。


上一篇
JAX 好好玩 (19) : 控制流程 (1) : Python 的問題
下一篇
JAX 好好玩 (21) : 控制流程 (3) : cond
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言