(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)
習慣使用 C/C++ switch/case 指令的程式工程師,對於 Python 中的 if-elif-else 的結構,多少都會覺得不順眼。雖然 Python 並沒有提供 switch,JAX 在設計的時候,貼心地把 switch 包含進來。
(註:在 Python 3.10 版開始,提出了 match-case 控制流程,基本上它就是 Python 裏的 switch/case。)
JAX switch 是一個相當簡化的 switch 控制流程,它接受一個索引 (index),一個函式序列 (branches) 及一組輸入參數 (*operands),依據索引值,決定呼叫函式序列中的那一個函式,呼叫時以 operands 為其輸入參數。
若是索引值超過範圍,則呼叫序列中最後一個函式。負的索引值則呼叫序列中第一個函式。(也就是相當於對輸入的 index 做所謂的 clamp 操作,參考下面的 clamp 函式。)
switch 在第一次被呼叫的時候,所有在序列裏的函式都會被追踪及編譯,這是 switch 的既定動作,我們無須宣告(修飾)這些函式為 @jax.jit。
switch 的運作流程,可以用以下的 Python 程式段來說明:
def clamp(n, smallest, largest):
return max(smallest, min(n, largest))
def switch(index, branches, *operands):
index = clamp(index, 0, len(branches) - 1)
return branches[index](*operands)
在這也簡單地定義了 clamp 操作,供讀者們參考。
首先定義好函式及函式序列,也定義好函式的共用輸入參數:
def get_zero_samples(extra):
print(f'In zero, you will receive {0+extra} samples.')
return 0+extra
def get_one_samples(extra):
print(f'In one, you will receive {1+extra} samples.')
return 1+extra
def get_two_samples(extra):
print(f'In two, you will receive {2+extra} samples.')
return 2+extra
def get_three_samples(extra):
print(f'In three, you will receive {3+extra} samples.')
return 3+extra
fun_list = [get_zero_samples, get_one_samples, get_two_samples, get_three_samples]
extra = 10
第一次呼叫 switch:
# normal case
index = 1
lax.switch(index, fun_list, extra)
output:
從第一次呼叫的結果可以看出來,儘管只有一個函式會被實際執行,但所有的函式都會被追踪及編譯。
再來實驗索引超過範圍的情況:
# out of bond 1
index = 5
lax.switch(index, fun_list, extra)
output:
DeviceArray(13, dtype=int32, weak_type=True)
# out of bond 2
index = -2
lax.switch(index, fun_list, extra)
output:
DeviceArray(10, dtype=int32, weak_type=True)
可以明顯的看出 clamp 的結果。