還記得老頭最初接觸 JAX 的前一個月,雖然從各類的文獻中讀到有關 JAX 的定義、它的組成元件、和它的某些應用,但總是摸不到對於 JAX 的「感覺」 ,NO Fu。
舉個例子來說,儘管不見得說得出精確的定義,我們都知道「筆電、桌機、手機、平板」是什麼。面對不同的情境、不同的工作需求,選擇適當的裝置這種思維過程,對於我們大多數人都是輕而易舉的,這就是老頭對於 JAX 想要找到的感覺。
具體而言,這個感覺就是:
老頭沒有能力用一次貼文的篇幅,把「JAX 是什麼 ? 」這個問題說清楚,讓讀者們找到對的感覺。唯有藉著「JAX 好好玩」這一系列的貼文,把老頭的探索過程分享給大家,期待於未來的某一天,你能夠突然雲開見日,找到自己的感覺。
撇開感覺不談,老頭總是要把「JAX 是什麼 ? 」做一個交待。
JAX 是 Google Brain (R. Frostig, M.J. Johnson, C. Leary) 在 2018 SysML Conference 發表的 (https://research.google/pubs/pub47008/ ),其論文標題是:
Compiling machine learning programs via high-level tracing
它的摘要是:
We describe JAX, a domain-specific tracing JIT compiler for generating high-performance accelerator code from pure Python and Numpy machine learning programs. JAX uses the XLA compiler infrastructure to generate optimized code for the program subroutines that are most favorable for acceleration, and these optimized subroutines can be called and orchestrated by arbitrary Python. Because the system is fully compatible with Autograd, it allows forward- and reverse-mode automatic differentiation of Python functions to arbitrary order. Because JAX supports structured control flow, it can generate code for sophisticated machine learning algorithms while maintaining high performance. We show that by combining JAX with Autograd and Numpy we get an easily programmable and highly performant ML system that targets CPUs, GPUs, and TPUs, capable of scaling to multi-core Cloud TPUs.
從這段摘要中,我們可以概略的了解 JAX 是什麼:
一言以蔽之,JAX 提供一套工具及環境,讓執行機器學習任務的 Python 及 Numpy 程式,能夠有效率的在 CPU、GPU、以及 TPU 上運作。
接下來,老頭針對其中的幾項功能,做一些概略的介紹:
編譯器(compiler)
而這套工具的最主要功能,是提供一個具有 high-level tracing 技術的編譯器 (compiler),來編譯機器學習的程式碼。而 JAX 它要編譯的對象,是 Python 和 Numpy 程式碼。
Python 基本上是一個採用「直譯 (interpret)」式的程式語言[4.1],優點是方便、執行時容易除錯,缺點是沒效率。使用編譯方式,就是要改善這一個缺點。
Numpy API
原始的 Numpy 程式碼,只能在 CPU 上執行,並不支援在 GPU 及 TPU 上運行。JAX 另外提供了一組和 Numpy 幾乎完全一樣的 API,配合 JAX 上的 Python 編譯器,可以使 Numpy 程式碼執行的速度大增數十倍(甚至於數百倍)!老頭稍後會詳細介紹這一部份。
Just in Time 式的編譯
JIT (just-in-time) 是指編譯發生的時間,是在原始碼執行的時候才進行。傳統的編譯式程式語言,像是 C 及 C++,在寫完原始碼之後直接編譯(及連結函式庫),產生執行檔供未來執行用。而 JAX 並不這樣做,它的 JIT 編譯原則是:
* 不是所有的原始碼都要編譯。在撰寫程式的時候,由程式設計師來指定那些函式 (Functions) 需要 JIT 編譯。
* 程式執行時,第一次碰到指定的函式時,即進行 JIT 編譯,把編譯結果暫存起來,並執行它。
* 上述編譯及暫存的動作,就是 Tracing 。
* XLA 則是 JAX 所採用的編譯器。
* 程式執行時,再次碰到指定函式時,就直接把暫存的編譯結果拿來用,不用再編一次 (此敍述簡化了許多重要的檢查,請大家暫時先這麼理解,之後會有貼文詳細說明)。
* 程式結束後,暫存的編譯結果也隨之被清除,並不會將它們以檔案的形式保存起來。
Autograd 自動梯度計算
autograd 是用來計算微分和偏微分,藉以得到梯度函數 (gradient function) 的一組 API。在機器學習 (深度學習) 中,反向傳播 (backpropagation; BP) 是最重要的運算之一,反向傳播的過程,其實就是計算當時各個模型參數 (model parameters) 相對於損失函數 (loss function) 的梯度,並用這個梯度來調整參數。
雜雜亂亂說了一堆,在這總結一下:
從這篇論文看來,Google 發展 JAX 的初衷,是為了要提升機器學習程式的運算效率,使其善用 CPU、GPU、及 Google 自己的 TPU,表面上,似乎並沒有把它發展成機器學習框架的企圖。然而在 2022 年的此時,為什麼有那麼多人,把 JAX 視為機器學習框架 (或準框架) 並給予高度的期待?這一點就留待日後慢慢來談。
而目前 Google 官方,仍然把 JAX 定位為一個實驗性的開源專案。外界普遍的推測(或者是說期待),要嘛 Google 會將 JAX 發展的更完整,讓它取代 TensorFlow 成為新一代的 Google 官方機器學習框架;要嘛 Google 會將 JAX 整合進下一版 TensorFlow (version 3) ,就像它在 2.0 版把 Keras 整合進來一般。到底會怎麼發展,就讓我們拭目以待吧!
老頭在此只對 JAX 做些初步的介紹,並不期待這短短的幾段文字,真的能把 JAX 說清楚、講明白。對於那些求知若渴,非得在第一時間弄清楚 JAX 是什麼的讀者,老頭推薦以下幾篇文章,請大家自行去閱讀了!
註
[4.1] 說 Python 是直譯式的程式語言,可能會造成某些爭議,Jack Cheng 在 Medium 上的文章「Python進階技巧 (5) — Python 到底怎麼被執行?直譯、編譯、字節碼、虛擬機看不懂?」把 Python 碼執行的方式說得很清楚,想要深入了解的讀者,趕快去看 ( Python進階技巧 )