(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
JIT … work by tracing a function to determine its effect on inputs of a specific shape and type. … [15.1]
在 JIT 的機制之下,當函式第一次被執行時,JAX 會先「追踪 (trace)」這個函式,並產生一個中間階段碼,稱之為「JAX 表示式 (JAX Expression, jaxpr)」接著,JAX 編譯器再將 JAX 表示式轉換為可執行碼。
現在,就讓我們深入探討「追踪」在做些什麼事。
要點一:在追踪時,會記下此次追踪函式輸入參數的維度 (shape) 及型態 (type),並且用追踪物件 (tracer object) 來表示輸入參數及函式運算過程中所涉及的其他變數。
「追踪」可以說是「針對執行時輸入參數的維度及型態,分析這個函式要做的動作,以及其想到達成的結果,產生可以有效率執行的 JAX 表示式。」在追踪時,JAX 會把輸入參數轉換為「追踪物件 (tracer object)」,同時在函式運算過程中所使用的暫時變數,也都會用追踪物件來表示。
我們故意在以下的 JIT 函式中加入 print(),藉以看看這些追踪物件:
@jax.jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
result = jnp.matmul(x + 1, y + 1)
print(f" result = {result}")
print("=======================================")
return result
key = jax.random.PRNGKey(0)
key, key1, key2 = jax.random.split(key,num=3)
x = jax.random.normal(key1, (3, 4))
y = jax.random.normal(key2, (4,))
f(x, y)
output:
*Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>
'======================================='
DeviceArray([1.4153224, 2.1780725, 2.5720434], dtype=float32)
JIT 函式 f() 的輸入參數 x 和 y 在追踪時,分別被表示為
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
這兩個追踪物件。而運算過程中產生的 result,也被表示為追踪物件:
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>
要點二:在追踪的過程中,函式內所有的 Python 指令及敍述,都會被執行一次。
追踪的目的之一,是要分析函式的運算流程,來生成 JAX 表示式。因此,函式內所有的 Python 敍述都會被執行一次。
要點三:在追踪後產生的 JAX 表示式,不會包括像 print() 等有副作用的 Python 指令。
JAX 表示式,經編譯後,要能夠跑在 GPU 及 TPU 等硬體加速器上,它不能包括 print() 等有副作用的敍述。
要點四:當函式再次被執行時,如果此次的輸入參數的維度和型態不變,和前次相同,那麼就無須再次追踪了,可以直接跑可執行碼。
再執行 JIT 函式 f() 一次:
# to call f(x,y) 2nd time
f(x,y)
output:
*DeviceArray([1.4153224, 2.1780725, 2.5720434], dtype=float32)
這次,是直接執行編譯後的執行碼,因此所有 print() 的訊息都不會再被輸出了,因為它們並不在 JAX 表示式中。
要點五:當函式再次被執行時,如果此次的輸入參數的維度和型態和前次不相同,JAX 會「再追踪」一次。
老頭之前曾說,函式在第一次執行時會被追踪,這只涵蓋了一小部份的狀況,再精確的說,當一個 JIT 函式要執行時,如果之前此函式,針對相同維度及型態的輸入參數,曾被追踪過,那麼不必再追踪一次了。否則,會被再追踪一次,並記錄下來此次的輸入參數的維度和型態。
# declare new input parameter with different shapes
x1 = jax.random.normal(key1, (4, 5))
y1 = jax.random.normal(key2, (5,))
f(x1, y1)
output:
Running f():
x = Traced<ShapedArray(float32[4,5])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
'======================================='
DeviceArray([-0.94372165, 1.2157216 , 6.2118015 , 6.090805 ], dtype=float32)
從上面的程式片斷我們可以看到,如果輸入參數的維度和之前不同時,呼叫 f() 會導致再一次的追踪。因此 print() 的訊息也再一次被輸出了。
要點六:因為 JAX 表示式中的程式控制流程 (control flow),不能參考到追踪物件的值。
所謂的控制流程,是指 Python 裏 if / for / while 等流程控制指令。舉個例子:
@jax.jit
def f2(x, neg):
return -x if neg else x # will raise run-time error
f2(1, True)
output:
'---------------------------------------------------------------------------'
*UnfilteredStackTrace Traceback (most recent call last)
in
4
----> 5 f2(1, True)
…
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: >Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with thebool
function.
The error occurred while tracing the function f at :1 for jit. >This concrete value was not available in Python because it depends on the value of the argument >'neg'.
這個程式想要判斷輸入參數 neg 的值,來決定回傳 -x 或 x。很明顯的違反要點六,在函式執行時會報錯。
要點七:最後,如果我們想要知道某一個函式它的 JAX 表示式,從而事先了解它在執行時的效率表現,可以使用 JAX 提供的 make_jaxpr() API :
jax.make_jaxpr(f)(x,y)
output:
在這裏只是給大家一個初步的概念,知道 JAX 表示式大概是什麼樣子。從中可以看出來,它只是在執行 jnp.matmul(x + 1, y + 1) 計算而已,並沒有包含 print()。
[15.1] 本文主要參考了 JAX 文件官網 JIT mechanics: tracing and static variables 一文的內容。