iT邦幫忙

2022 iThome 鐵人賽

3
AI & Data

JAX 好好玩系列 第 40

JAX 好好玩 (40) : JAX 到底是什麼 ?

  • 分享至 

  • xImage
  •  

在更進一步「玩」JAX 之前,老頭想要回到我們對於 JAX 最初的疑問 – JAX 是什麼 ? 前面零零碎碎的 39 篇貼文,似乎有點見樹不見林,仍舊沒有辨法給大家一個全面的概念。在這一個貼文裏,老頭會試著用較為宏觀的角度來看 JAX,回答「JAX 到底是什麼 ? 」這個問題。

JAX 提供建造及訓練「高效率神經網路模型」的零件

零件」是組成一個整體的基本單位。JAX 提供各類的零件,藉以組裝成一個神經網路模型,並確保這個模型可以高效率的訓練 (training) 及推理 (inference)。這些零件包括:

jax.numpy : 與 numpy 幾乎完全相同的 API
jax.DeviceArray : 可以直接在 GPU/TPU 上運作的陣列資料結構
jax.jit : 可以即時編譯函式的轉換機制 (transformation)
jax.vmap:自動向量化 (Automatic Vectorization) 機制,用以處理批次資料。
jax.grad:自動微分機制,用於偏微分及導數 (或梯度) 的計算。
jax control flow:針對純函式的需求所設計的控制流程 (control flow)。
pytree:JAX 定義的資料結構,用來包裝模型參數 (model parameters) 和資料集 (dataset entries) 等。

之前的貼文,已經對這些零件做了初步的介紹。另外還有些老頭尚未說明的零件,其中重要的有:

jax.pmap:多加速器的平行運算機制。
jax.scipy:JAX 版本的 scipy 函式庫及工具包。

許多開源專案,將 JAX 零件組合成更高階的「組件」

如果所有的模型都必須從零件開始組裝,將不僅費時費力,也沒有必要。於是乎許多開源專案試著將 JAX 的零件組合成功能更完整的「組件」,方便使用者更快速簡捷的打造神經網路。老頭之前介紹的 Flax 即是其中一例。其他重要的專案包括:

Haiku:由 DeepMind 提供的類似 Flax 的神經網路組件。
Rlax:支援增強學習的組件。
FedJax:支援聯邦學習的組件。
Jraph:由 DeepMind 提供的支援 graph 神經網路的組件。
Equinox:類似 Flax 及 Haiku,但其特別強調像 pytorch 的 API。

如果沒有要用特別的方法(如聯邦學習等),老頭建議大家要熟悉 Flax 及 Haiku,這樣大致可以應付大部份應用上的需求。

什麼時候要用到 JAX (及其「組件」)

台灣業界的 AI 專案以應用為主,老頭建議如果用 TensorFlow 或 Pytorch 能完成的工作,就不需要用 JAX 來做。但是在以下的狀況,JAX 將是非常好的選擇。

  • 如果你要訓練一個超大型的模型,速度及效率是重要的因素。
  • 如果你是一個研究人員,要建構一個新的神經網路模型結構。
  • 如果你想要加速現有的 Numpy 資料處理。
  • 當在現有的框架下 (如 TensorFlow 或 Pytorch),模型訓練的時間太長,不符合專案的需求。

JAX 需要更長的學習曲線

JAX 的學習曲線 (learning curve) 較一般的 AI 框架長很多,也更需要完備的 AI 基礎知識,因此,如果想要導入 JAX 的組織,要及早訓練你的成員。在開始學習 JAX 之前,你的團隊當然需要對深度學習及神經網路有著相當程度的了解, 也要先熟悉「純函式 (pure function)」 和 「功能式程式設計 (functional programming)」。目前 JAX 提供的除錯 (debugging) 機制有限,計劃開始前完整的教育訓練,是彌補這一缺點的好方法。


上一篇
JAX 好好玩 (39) : Flax (5) : 輔助函式及單一批次訓練函式
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言