iT邦幫忙

2022 iThome 鐵人賽

DAY 18
0
AI & Data

JAX 好好玩系列 第 18

JAX 好好玩 (18) : JAX JIT (7) : 總結與回顧

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

是時候來總結一下 JAX JIT 了 [18.1]。

追踪的過程

https://ithelp.ithome.com.tw/upload/images/20220919/20129616Dg1d2zV2Z1.png

JAX JIT 的運作,首先即是「追踪」函式中的 Python 原始碼,追踪的過程會執行以下重要的動作:

  • 將輸入參數轉成追踪物件 (tracer object)

    除非程式設計時特別把某些參數宣告為「靜態輸入參數」,否則 JAX 一律以追踪物件來處理輸入參數。JAX 會記錄追踪物件的「維度 (shape)」和「型態 (type)」,在追踪的過程中,JAX 並不會參考追踪物件的「值 (value) 」。

    在追踪的過程中,對於「靜態輸入參數」,JAX 會將其視為「常數」。

  • 追踪函式內的每一條 Python 敍述 (Python statement)

    追踪時,不管它們有沒有副作用,每一條敍述都會被追踪並執行一次。JAX 會記錄那些和追踪物件有關的運算,並利用最佳化的方法,將這些運算轉成 JAX 表示式;而那些副作用及 print 等敍述,將直接被捨棄,並不會被放入表示式。

「追踪」所產生的 JAX 表示式,將被編譯成可以在 CPU/GPU/TPU 上運作的可執行碼,JAX 延用在 TensorFlow 上的 XLA (Accelerated Linear Algebra; 加速線性代數) 編譯環境。因為 XLA 是下層的細節,老頭並沒有計劃在「JAX 好好玩」這一個專題介紹它,有興趣的讀者可以參考 TensorFlow XLA 的相關網頁 [18.2]

編譯好的可執行碼,會被 JAX 暫存(caching)起來,下一回當函式再一次被呼叫時,JAX 會執行下列的檢查,以決定是否可以直接使用暫存的可執行碼,或是需要另外再追踪一次:

  • 對於「靜態輸入參數」,檢查此次呼叫的「值 (value)」是否和暫存的可執行碼一致。
  • 對於一般輸入參數,檢查此次呼叫的維度和型態,是否和暫存的一致。

當然,程式執行一段時間後,某一個 JIT 函式可能會有許多個暫存可執行碼,新的呼叫會檢查所有的暫存碼,找到匹配的那一個。

JIT 函式注意事項:

前面的貼文,老頭介紹過了撰寫 JIT 函式要注意的地方,整理一下給大家參考:

  • 儘量遵循純函式的規則,否則你必須清楚知道函式中的副作用會造成的結果。
  • 函式內的控制流程 (control flow) 不要參考輸入參數的值。
  • 函式內所涉及的陣列維度,在輸入參數維度決定後,它們都必須是決定的。
  • 對於不想被追踪的輸入變數,可以宣告其為「靜態輸入參數」,但如果這個參數的值在每次呼叫函式的時候都會變化,將會造成函式每次都被追踪,降低執行效率。

對於 JIT 的初步介紹,就在此告一段落。其實有一個要注意的地方老頭並沒有講得很清楚,它是有關 Python 流程控制指令 (if, while, for) 。除了「函式內的控制流程不要參考輸入參數的值」之外,在 JIT 函式裏使用這些流程控制指令也有許多要注意的地方,接下來的貼文,老頭就要和大家討論和 JIT 息息相關的「控制流程 (control flow) 」。

註:

[18.1] 本文主要參考了 JAX 文件官網 「Just In Time Compilation with JAX」” 一文的內容。

[18.2] 可參考 XLA :機器學習的最佳化編譯器


上一篇
JAX 好好玩 (17) : JAX JIT (6) : 解密全域變數的怪異行為
下一篇
JAX 好好玩 (19) : 控制流程 (1) : Python 的問題
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言