iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 39

JAX 好好玩 (39) : Flax (5) : 輔助函式及單一批次訓練函式

  • 分享至 

  • xImage
  •  

這篇貼文,乃是針對「第二個範例程式」中的「4. 輔助函式」及「5.1. 單一 batch 訓練函式」的部份,加以詳細的說明。

損失函式

Flax 建議的損失函式寫法是:

  • 輸入參數為 (1) 一個批次的模型輸出 (我們的例子是一組 logits) 和 (2) 其對應的 labels。
  • 定義一個子函式,可稱之為「單一筆資料的損失函式」。它的輸入參數是單一筆模型輸出,及單一筆 label。
  • 利用 vmap() 轉換子函式,令其可以處理一個批次。
  • 利用 mean() 算出整個 batch 的損失值,它必須是一個純量。
    https://ithelp.ithome.com.tw/upload/images/20221103/201296165ngweJfEAf.png

參數調整 (最佳化) 方式

模型參數的調整,也就是所謂的「參數最佳化 optimizing」。Optax 為 JAX 提供了許多實作好的最佳化演算法,包括 adam、sgd 等等,詳細列表可參考 Optax 官方文件網頁 [39.1]。

# 定義兩種最佳化的方法
# 1. 「隨機梯度下降法(Stochastic gradient descent, SGD)」.
# 2. 「 adam 法」 
 
learning_rate = 0.01
learning_momentum = 0.9
 
optim_sgd = optax.sgd(learning_rate, learning_momentum)
optim_adam = optax.adam(learning_rate)
 

「訓練狀態」及單一批次的訓練函式

Flax 提供了「訓練狀態」這一個物件,來追踪訓練過程,使程式更為精簡 [39.2]。我們首先要創立一個訓練狀態的案例,並提供 (1) 模型的 apply 函式 、(2) 模型的參數、 (3) 模型的參數調整方式。

state1 = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

而在每一個批次訓練時,算出 grad() 並利用「訓練狀態」來調整模型的參數。
https://ithelp.ithome.com.tw/upload/images/20221103/20129616Fn5xnLcEFm.png

這些準備工作完成之後,就可以執行訓練了。

[39.1] 可參考 Optax API 之 Common Optimizers 的說明。

[39.2] 在 Flax 官方文件中,也提到了可以用 Optax 提供的方式來輔助訓練過程中的參數調整,讀者可參考其提供的 範例程式 Optimizing with Optax。比較 Flax 的「訓練狀態」, Optax 的方法稍微複雜了一點。
https://ithelp.ithome.com.tw/upload/images/20221103/20129616MlJwpTIkOq.png


上一篇
JAX 好好玩 (38) : Flax (4) : 自訂模型
下一篇
JAX 好好玩 (40) : JAX 到底是什麼 ?
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言