(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
是時候來總結一下 JAX JIT 了 [18.1]。
JAX JIT 的運作,首先即是「追踪」函式中的 Python 原始碼,追踪的過程會執行以下重要的動作:
將輸入參數轉成追踪物件 (tracer object)
除非程式設計時特別把某些參數宣告為「靜態輸入參數」,否則 JAX 一律以追踪物件來處理輸入參數。JAX 會記錄追踪物件的「維度 (shape)」和「型態 (type)」,在追踪的過程中,JAX 並不會參考追踪物件的「值 (value) 」。
在追踪的過程中,對於「靜態輸入參數」,JAX 會將其視為「常數」。
追踪函式內的每一條 Python 敍述 (Python statement)
追踪時,不管它們有沒有副作用,每一條敍述都會被追踪並執行一次。JAX 會記錄那些和追踪物件有關的運算,並利用最佳化的方法,將這些運算轉成 JAX 表示式;而那些副作用及 print 等敍述,將直接被捨棄,並不會被放入表示式。
「追踪」所產生的 JAX 表示式,將被編譯成可以在 CPU/GPU/TPU 上運作的可執行碼,JAX 延用在 TensorFlow 上的 XLA (Accelerated Linear Algebra; 加速線性代數) 編譯環境。因為 XLA 是下層的細節,老頭並沒有計劃在「JAX 好好玩」這一個專題介紹它,有興趣的讀者可以參考 TensorFlow XLA 的相關網頁 [18.2]
編譯好的可執行碼,會被 JAX 暫存(caching)起來,下一回當函式再一次被呼叫時,JAX 會執行下列的檢查,以決定是否可以直接使用暫存的可執行碼,或是需要另外再追踪一次:
當然,程式執行一段時間後,某一個 JIT 函式可能會有許多個暫存可執行碼,新的呼叫會檢查所有的暫存碼,找到匹配的那一個。
前面的貼文,老頭介紹過了撰寫 JIT 函式要注意的地方,整理一下給大家參考:
對於 JIT 的初步介紹,就在此告一段落。其實有一個要注意的地方老頭並沒有講得很清楚,它是有關 Python 流程控制指令 (if, while, for) 。除了「函式內的控制流程不要參考輸入參數的值」之外,在 JIT 函式裏使用這些流程控制指令也有許多要注意的地方,接下來的貼文,老頭就要和大家討論和 JIT 息息相關的「控制流程 (control flow) 」。
註:
[18.1] 本文主要參考了 JAX 文件官網 「Just In Time Compilation with JAX」” 一文的內容。
[18.2] 可參考 XLA :機器學習的最佳化編譯器 。