iT邦幫忙

2022 iThome 鐵人賽

DAY 23
0
AI & Data

JAX 好好玩系列 第 23

JAX 好好玩 (23) : 控制流程 (5) : switch

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

習慣使用 C/C++ switch/case 指令的程式工程師,對於 Python 中的 if-elif-else 的結構,多少都會覺得不順眼。雖然 Python 並沒有提供 switch,JAX 在設計的時候,貼心地把 switch 包含進來。

註:在 Python 3.10 版開始,提出了 match-case 控制流程,基本上它就是 Python 裏的 switch/case。

switch() 的語法及語義

https://ithelp.ithome.com.tw/upload/images/20220926/20129616vnYEmPH9tW.png

JAX switch 是一個相當簡化的 switch 控制流程,它接受一個索引 (index),一個函式序列 (branches) 及一組輸入參數 (*operands),依據索引值,決定呼叫函式序列中的那一個函式,呼叫時以 operands 為其輸入參數。

若是索引值超過範圍,則呼叫序列中最後一個函式。負的索引值則呼叫序列中第一個函式。(也就是相當於對輸入的 index 做所謂的 clamp 操作,參考下面的 clamp 函式。)

switch 在第一次被呼叫的時候,所有在序列裏的函式都會被追踪及編譯,這是 switch 的既定動作,我們無須宣告(修飾)這些函式為 @jax.jit。

switch 的運作流程,可以用以下的 Python 程式段來說明:

def clamp(n, smallest, largest):
    return max(smallest, min(n, largest))

def switch(index, branches, *operands):
    index = clamp(index, 0, len(branches) - 1)
    return branches[index](*operands)

在這也簡單地定義了 clamp 操作,供讀者們參考。

switch() 的使用範例

首先定義好函式及函式序列,也定義好函式的共用輸入參數:

def get_zero_samples(extra):
    print(f'In zero, you will receive {0+extra} samples.')
    return 0+extra

def get_one_samples(extra):
    print(f'In one, you will receive {1+extra} samples.')
    return 1+extra

def get_two_samples(extra):
    print(f'In two, you will receive {2+extra} samples.')
    return 2+extra

def get_three_samples(extra):
    print(f'In three, you will receive {3+extra} samples.')
    return 3+extra

fun_list = [get_zero_samples, get_one_samples, get_two_samples, get_three_samples]
extra = 10

第一次呼叫 switch:

# normal case
index = 1
lax.switch(index, fun_list, extra) 

output:
https://ithelp.ithome.com.tw/upload/images/20220926/20129616UXJr0BkK0l.png

從第一次呼叫的結果可以看出來,儘管只有一個函式會被實際執行,但所有的函式都會被追踪及編譯。

再來實驗索引超過範圍的情況:

# out of bond 1
index = 5
lax.switch(index, fun_list, extra) 

output:
DeviceArray(13, dtype=int32, weak_type=True)

# out of bond 2
index = -2
lax.switch(index, fun_list, extra) 

output:
DeviceArray(10, dtype=int32, weak_type=True)

可以明顯的看出 clamp 的結果。


上一篇
JAX 好好玩 (22) : 控制流程 (4) : while_loop
下一篇
JAX 好好玩 (24) : 控制流程 (6) : scan
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言