iT邦幫忙

2022 iThome 鐵人賽

DAY 15
0
AI & Data

JAX 好好玩系列 第 15

JAX 好好玩 (15) : JAX JIT (4) : 追踪 (Tracing)

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

JIT … work by tracing a function to determine its effect on inputs of a specific shape and type. … [15.1]

在 JIT 的機制之下,當函式第一次被執行時,JAX 會先「追踪 (trace)」這個函式,並產生一個中間階段碼,稱之為「JAX 表示式 (JAX Expression, jaxpr)」接著,JAX 編譯器再將 JAX 表示式轉換為可執行碼。
https://ithelp.ithome.com.tw/upload/images/20220919/20129616P0PlwNTA5P.png

現在,就讓我們深入探討「追踪」在做些什麼事。

要點一:在追踪時,會記下此次追踪函式輸入參數的維度 (shape) 及型態 (type),並且用追踪物件 (tracer object) 來表示輸入參數及函式運算過程中所涉及的其他變數。

「追踪」可以說是「針對執行時輸入參數的維度及型態,分析這個函式要做的動作,以及其想到達成的結果,產生可以有效率執行的 JAX 表示式。」在追踪時,JAX 會把輸入參數轉換為「追踪物件 (tracer object)」,同時在函式運算過程中所使用的暫時變數,也都會用追踪物件來表示。

我們故意在以下的 JIT 函式中加入 print(),藉以看看這些追踪物件:

@jax.jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.matmul(x + 1, y + 1)
  print(f"  result = {result}")
  print("=======================================")
  return result
 
key = jax.random.PRNGKey(0)
key, key1, key2 = jax.random.split(key,num=3)
 
x = jax.random.normal(key1, (3, 4))
y = jax.random.normal(key2, (4,))
f(x, y)

output:
*Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>
'======================================='
DeviceArray([1.4153224, 2.1780725, 2.5720434], dtype=float32)

JIT 函式 f() 的輸入參數 x 和 y 在追踪時,分別被表示為

x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>

這兩個追踪物件。而運算過程中產生的 result,也被表示為追踪物件:

result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>

要點二:在追踪的過程中,函式內所有的 Python 指令及敍述,都會被執行一次。

追踪的目的之一,是要分析函式的運算流程,來生成 JAX 表示式。因此,函式內所有的 Python 敍述都會被執行一次。

要點三:在追踪後產生的 JAX 表示式,不會包括像 print() 等有副作用的 Python 指令。

JAX 表示式,經編譯後,要能夠跑在 GPU 及 TPU 等硬體加速器上,它不能包括 print() 等有副作用的敍述。

要點四:當函式再次被執行時,如果此次的輸入參數的維度和型態不變,和前次相同,那麼就無須再次追踪了,可以直接跑可執行碼。

再執行 JIT 函式 f() 一次:

# to call f(x,y) 2nd time
f(x,y)

output:
*DeviceArray([1.4153224, 2.1780725, 2.5720434], dtype=float32)

這次,是直接執行編譯後的執行碼,因此所有 print() 的訊息都不會再被輸出了,因為它們並不在 JAX 表示式中。

要點五:當函式再次被執行時,如果此次的輸入參數的維度和型態和前次不相同,JAX 會「再追踪」一次。

老頭之前曾說,函式在第一次執行時會被追踪,這只涵蓋了一小部份的狀況,再精確的說,當一個 JIT 函式要執行時,如果之前此函式,針對相同維度及型態的輸入參數,曾被追踪過,那麼不必再追踪一次了。否則,會被再追踪一次,並記錄下來此次的輸入參數的維度和型態。

# declare new input parameter with different shapes
 
x1 = jax.random.normal(key1, (4, 5))
y1 = jax.random.normal(key2, (5,))
f(x1, y1)

output:
Running f():
x = Traced<ShapedArray(float32[4,5])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
'======================================='
DeviceArray([-0.94372165, 1.2157216 , 6.2118015 , 6.090805 ], dtype=float32)

從上面的程式片斷我們可以看到,如果輸入參數的維度和之前不同時,呼叫 f() 會導致再一次的追踪。因此 print() 的訊息也再一次被輸出了。

要點六:因為 JAX 表示式中的程式控制流程 (control flow),不能參考到追踪物件的值。

所謂的控制流程,是指 Python 裏 if / for / while 等流程控制指令。舉個例子:

@jax.jit
def f2(x, neg):
  return -x if neg else x # will raise run-time error
 
f2(1, True)

output:
'---------------------------------------------------------------------------'
*UnfilteredStackTrace Traceback (most recent call last)
in
4
----> 5 f2(1, True)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: >Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the bool function.
The error occurred while tracing the function f at :1 for jit. >This concrete value was not available in Python because it depends on the value of the argument >'neg'.

這個程式想要判斷輸入參數 neg 的值,來決定回傳 -x 或 x。很明顯的違反要點六,在函式執行時會報錯。

要點七:最後,如果我們想要知道某一個函式它的 JAX 表示式,從而事先了解它在執行時的效率表現,可以使用 JAX 提供的 make_jaxpr() API :

jax.make_jaxpr(f)(x,y)

output:
https://ithelp.ithome.com.tw/upload/images/20220919/2012961697SpJomOE1.png

在這裏只是給大家一個初步的概念,知道 JAX 表示式大概是什麼樣子。從中可以看出來,它只是在執行 jnp.matmul(x + 1, y + 1) 計算而已,並沒有包含 print()。

[15.1] 本文主要參考了 JAX 文件官網 JIT mechanics: tracing and static variables 一文的內容。


上一篇
JAX 好好玩 (14) : JAX JIT (3) : 函式內陣列的維度
下一篇
JAX 好好玩 (16) : JAX JIT (5) : 如何不被追踪
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言