iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 35

JAX 好好玩 (35) : Flax (1) : 準備學習 Flax

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載 )

在之前的貼文中,老頭介紹了如何用 JAX 設計並訓練簡單的 MLP 神經網路模型。然而,如果今天我們想要用像是 CNN、RNN、Transformer 等架構,難道我們也要像先前的例子般,用 Python 和 JAX 一步步的把模型建構起來?當然不是!

隨著 JAX 的逐漸成熟,有越來越多的單位及個人,以 JAX 為基礎,開發了支援神經網路及機器學習中高階模型和演算法的函式庫,Flax 即是其中之一,起初它是由 Google Brain 團隊所設計,稍後並成為開放專案。Flax 和 JAX 的原始設計者都來自 Google Brain,兩個團隊密切配合,因此 Flax 在提供高階神經網路相關的功能時,也可以充分保留 JAX 的優點 (效率、彈性等)。

接下來老頭會以 Flax 官方文件「Getting Started」[35.1] 為藍本,稍微加以改寫,來跟大家介紹 Flax 到底怎麼用。在此之前,先補充些先備知識。

準備環境

目前 (2022/10/13) Colab 並沒有預載 flax,因此,我們必須用 pip 自行安裝。可以用如下的 try / except 結構,未來如果 colab 預載了 colab,就不致於重覆 import 了。

try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

此外,要完整的訓練一個神經網路,除了 Flax 之外,我們還需要 import 另外一個函式庫 Optax [35.2]。Optax 提供了「最佳化模型參數」的相關 API,也就是所謂的 optimizer API。起初它是被安排在 jax.experimental.optix 封裝內,是 JAX 的一部份,後來其研發團隊決定將其分割出來為獨立的開源專案,並改名為 Optax。目前其貢獻者主要來自 Deep Mind 及 Alphabet。

Colab 已經預載了 Optax,可以直接 import:

import optax                           # Optimizers

在稍後的「Getting Started」範例程式中,會使用到 Optax 兩個 API。

optax.softmax_cross_entropy(logits=logits, labels=labels_onehot)

這是一個非常方便好用的 API,它先將 logit 轉成 softmax,再和 onehot label 比對算出 cross entropy。如果在損失函式中使用這個 API,模型設計的時候,就只要輸出 logit 就可以了。

optim_standard = optax.sgd(learning_rate, learning_momentum)

範例程式將選用 Optax 提供的 sgd (Stochastic Gradient Descent) API,作為參數最佳化的演算法。

jax.value_and_grad() API

在「Getting Started」範例程式中,使用到 jax.value_and_grad() API,它是 jax.grad() 的一個變形,我們在此先看看它的用法。

假設 fun1() 是我們的損失函式,我們想要計算出損失值及其導數,一般來說,我們會分別呼叫 fun1() 和 jax.grad(fun1)()。

def fun1(x):
    y = x**3 + 2*x**2 - 3*x + 1
    return y
 
x = 5.
print(fun1(x))
print(jax.grad(fun1)(x))

output:
161.0
92.0

若使用 jax.value_and_grad(),就只要呼叫一次就好,它會同時回傳 tuple,其包括 fun1() 和 jax.grad(fun1)() 的計算結果。

print(jax.value_and_grad(fun1)(x))

output:
(DeviceArray(161., dtype=float32, weak_type=True), DeviceArray(92., dtype=float32, weak_type=True))

若是我們的損失函式不僅要算出損失值,也要負責回傳某些其他的資訊,如以下的函式 fun2(),我們在呼叫 jax.grad() 和 jax.value_and_grad() 時,就要帶 has_aux=True 參數,告訴 JAX 我們的損失函式會多回傳一個 auxiliary value。

# y 是此函式主要的輸出, 也是我們希望計算 grad() 的輸出.
# z 是此函式的額外輸出
# ===========================================================================================
def fun2(x):
    y = x**3 + 2*x**2 - 3*x + 1
    z = x + 1
    return y, z
 
x = 5.
print(fun2(x))
print(jax.grad(fun2, has_aux=True)(x))
print(jax.value_and_grad(fun2, has_aux=True)(x))

output:
https://ithelp.ithome.com.tw/upload/images/20221017/201296169008F5ismz.png

OK,大家應該準備好來研究第一個 Flax 範例程式了!

參考:

[35.1] Flax Documents: Getting Started .
[35.2] Optax 官方文件網址 : https://optax.readthedocs.io/en/latest/


上一篇
JAX 好好玩 (34) : 類別與 jit (2) : 註冊類別為 pytree
下一篇
JAX 好好玩 (36) : Flax (2) : 第一個範例程式
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言