在更進一步「玩」JAX 之前,老頭想要回到我們對於 JAX 最初的疑問 – JAX 是什麼 ? 前面零零碎碎的 39 篇貼文,似乎有點見樹不見林,仍舊沒有辨法給大家一個全面的概念。在這一個貼文裏,老頭會試著用較為宏觀的角度來看 JAX,回答「JAX 到底是什麼 ? 」這個問題。
「零件」是組成一個整體的基本單位。JAX 提供各類的零件,藉以組裝成一個神經網路模型,並確保這個模型可以高效率的訓練 (training) 及推理 (inference)。這些零件包括:
jax.numpy : 與 numpy 幾乎完全相同的 API
jax.DeviceArray : 可以直接在 GPU/TPU 上運作的陣列資料結構
jax.jit : 可以即時編譯函式的轉換機制 (transformation)
jax.vmap:自動向量化 (Automatic Vectorization) 機制,用以處理批次資料。
jax.grad:自動微分機制,用於偏微分及導數 (或梯度) 的計算。
jax control flow:針對純函式的需求所設計的控制流程 (control flow)。
pytree:JAX 定義的資料結構,用來包裝模型參數 (model parameters) 和資料集 (dataset entries) 等。
之前的貼文,已經對這些零件做了初步的介紹。另外還有些老頭尚未說明的零件,其中重要的有:
jax.pmap:多加速器的平行運算機制。
jax.scipy:JAX 版本的 scipy 函式庫及工具包。
如果所有的模型都必須從零件開始組裝,將不僅費時費力,也沒有必要。於是乎許多開源專案試著將 JAX 的零件組合成功能更完整的「組件」,方便使用者更快速簡捷的打造神經網路。老頭之前介紹的 Flax 即是其中一例。其他重要的專案包括:
Haiku:由 DeepMind 提供的類似 Flax 的神經網路組件。
Rlax:支援增強學習的組件。
FedJax:支援聯邦學習的組件。
Jraph:由 DeepMind 提供的支援 graph 神經網路的組件。
Equinox:類似 Flax 及 Haiku,但其特別強調像 pytorch 的 API。
如果沒有要用特別的方法(如聯邦學習等),老頭建議大家要熟悉 Flax 及 Haiku,這樣大致可以應付大部份應用上的需求。
台灣業界的 AI 專案以應用為主,老頭建議如果用 TensorFlow 或 Pytorch 能完成的工作,就不需要用 JAX 來做。但是在以下的狀況,JAX 將是非常好的選擇。
JAX 的學習曲線 (learning curve) 較一般的 AI 框架長很多,也更需要完備的 AI 基礎知識,因此,如果想要導入 JAX 的組織,要及早訓練你的成員。在開始學習 JAX 之前,你的團隊當然需要對深度學習及神經網路有著相當程度的了解, 也要先熟悉「純函式 (pure function)」 和 「功能式程式設計 (functional programming)」。目前 JAX 提供的除錯 (debugging) 機制有限,計劃開始前完整的教育訓練,是彌補這一缺點的好方法。