iT邦幫忙

2022 iThome 鐵人賽

DAY 25
0
AI & Data

JAX 好好玩系列 第 25

JAX 好好玩 (25) : 控制流程 (7) : 總結

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

老頭已經介紹了 JAX 提供的 5 個控制流程,事實上,除了這 5 個之外,JAX 還有 map 和 associative_scan 另外兩個控制流程。因為它們語法和語義相當單純,老頭就不在這裏多談了,有興趣的讀者可以直接去看 JAX 的官方文件 [25.1]。要注意的是 map,JAX 官網建議我們儘量使用 vmap,老頭很快的就會為大家介紹 vmap。

最後用一張表做為控制流程的總結 [25.2]:
https://ithelp.ithome.com.tw/upload/images/20220926/201296167yKRNCEaa8.png

先看 jit 那一欄。Python 的 if 是不建議使用在 JIT 函式內,而 for 及 while 可以被有限度的使用,但是要注意:

  • 它們會有「迴圈展開」的問題,你要評估一下系統的記憶體夠不夠展開這些迴圈。
  • 在迴圈的條件式裏,不可以參考到輸入參數的值,否則執行時會出現錯誤。

再來看 grad 這一欄。grad 是 JAX 的 Auto Diff 功能的 API 名稱,這一欄說明了這些指令及運算子是否能夠支援 Auto Diff。雖然老頭還沒有介紹 Auto Diff,但是以下幾點,仍希望大家先放在心裏,未來在 Auto Diff 的說明中,會有更清楚的解釋。

  • Auto Diff 有順向模式 (forward-mode) 和逆向模式 (reverse-mode) 兩種。
  • while_loop 僅支援順向模式。
  • fori_loop 支援順向模式,且有條件的支援逆向模式。
  • 表上其他的指令及運算子能完全支援 Auto Diff 兩種模式。

那麼 fori_loop 在那些條件下才能支援逆向模式呢?其實 JAX 在處理 fori_loop 時,會將其轉成 while_loop 或是 scan,當它被轉成 while_loop 時,就只支援順向模式了。若是它被轉成 scan ,那麼就可以同時支援逆向及順向模式了。

fori_loop 被轉成 scan 的條件是:當其被追踪時,若是迴圈被執行的次數能夠被決定,就可以被轉成 scan。

例如:

fori_loop(1, 10, my_body, my_argument)

這行指令在被追踪時,已經可以知道它的重覆次數是 9 次,所以可以轉成 scan。
而下面這行指令:

fori_loop(x, y, my_body, my_argument)

在追踪時,x 和 y 都會用「追踪物件」來表示,並不會直接參考它們的值,也就是說,迴圈的重覆次數不能在追踪時被決定,因此 JAX 會用 while_loop 來實踐這行指令。

對於控制流程的介紹就在這告一段落,接下來,我們就要進入 Auto Diff。

註:

[25.1] map 可參考 這裏,associative_scan 可參考這裏

[25.2] 參考 JAX 官網文件 (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#summary)


上一篇
JAX 好好玩 (24) : 控制流程 (6) : scan
下一篇
JAX 好好玩 (26) : Auto Diff (1) : grad 簡介
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言