iT邦幫忙

2022 iThome 鐵人賽

DAY 19
0
AI & Data

JAX 好好玩系列 第 19

JAX 好好玩 (19) : 控制流程 (1) : Python 的問題

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,[可由此下載]
(https://colab.research.google.com/drive/1tIu9KwFqp7dZ_vCLOZiQ0NK_Y0av0vGF?usp=sharing)

Python 本身並沒有什麼問題,然而搭配 JAX JIT 時,卻有一些行為是超出一般的想像,由其是與「控制流程 (control flow)」相關的程式段。這一回,我們就來看看兩個常見的問題。第一個是「執行時錯誤」,第二個是「迴圈展開」。

執行時錯誤

在 JIT 函式裏使用 Python 的控制流程指令 if / for / while 是要非常小心的,如果沒有遵循 JIT 函式的注意事項,有時會導致執行錯誤。例如以下的例子[19.1]:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
 
# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))

output:
Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the bool function.
The error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.

它違反了原則

  • 函式內的控制流程 (control flow) 不要參考輸入參數的值。(參考之前 JIT 的貼文)

另外一個違反同一個原則的例子 [19.1]:

def f(x, n):
    y = 0.
    for i in range(n):
      y = y + x[i]
    return y
 
f = jax.jit(f)
 
f(jnp.array([2., 3., 4.]), 3)

output:
UnfilteredStackTrace Traceback (most recent call last)
in
8
----> 9 f(jnp.array([2., 3., 4.]), 3)
....
....
TracerIntegerConversionError: The index() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

它的錯誤是「不該在控制流程 for 中參考到輸入參數 n 的值」(另一個說法是,函式內的執行流程,不能被輸入參數的值所影響)。這個問題的其中一個解決方法,是將 n 宣告為靜態輸入變數,使得 JAX 不再將其視為追踪物件,這樣,函式就可以正確執行了:

def f2(x, n):
    y = 0.
    for i in range(n):
      y = y + x[i]
    return y
 
f2 = jax.jit(f2, static_argnums=(1,))
 
f2(jnp.array([2., 3., 4.]), 3)

output:
DeviceArray(9., dtype=float32)

迴圈展開:

於 JIT 函式裏使用 Python 的 for 迴圈,會產生「迴圈展開 (loop unrolling)」的現象。老頭以下面的例子來解釋這一現象:

def f3(x):
    y = 0.
    for i in range(5):
      y = y + x[i]
    return y
 
f3 = jax.jit(f3)
 
f3(jnp.array([1., 2., 3., 4., 5.]))

output:
DeviceArray(15., dtype=float32)

函式 f3() 並沒有違反先前提到的原則,執行起來的結果也正確,表面上沒什麼問題,但是如果我們把 f3() 的 JAX 展開式列出來,你就會發現那裏不對勁了!

jax.make_jaxpr(f3)(jnp.array([1.,2., 3., 4.,5.]))

https://ithelp.ithome.com.tw/upload/images/20220926/20129616qn0Suu67sq.png

從上圖的註記,大家可以看到 y = y + x[i] 這一行指令被重覆了 5 次 !這種將迴圈內的程式片斷展開成平鋪直述程式段,即是「迴圈展開」。設想,如果我們有一個重覆 100 萬次的迴圈,那麼所有產生的執行碼會多大?這絕對是不能被接受的。為什麼會這樣呢?因為 Python (及其他高階程式語言) 在設計時,主要是以在 CPU 上執行為其考量點,並沒有考慮到這些控制指令是否能在 GPU/TPU 等硬體加速器上執行,JAX 在追踪這一類的 Python 指令時,並無法找到對應的 JAX 表示式,只能將其展開,將迴圈轉變為一行行循序執行的指令序列。

JAX 控制流程運算子

為了解決這些問題,JAX 提出了自己控制流程運算子 (control flow operators),其中重要的是:

  • cond () : 取代 python if
  • fori_loop () : 取代 python for
  • while_loop () : 取代 python while
  • switch () : 類似 C/C++ 的 switch
  • scan () : JAX 獨特的運算子,與 TensorFlow 的 scan 類似

接下來,老頭將一一介紹這些運算子。

註:

[19.1] 這些例子來自於 JAX 文件管網 python control flow + JIT


上一篇
JAX 好好玩 (18) : JAX JIT (7) : 總結與回顧
下一篇
JAX 好好玩 (20) : 控制流程 (2) : fori_loop
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言