這一個範例程式的目的,是給大家一個整體的概念,看看要設計及訓練一個 Flax 神經網路模型要做那些事情。老頭改寫了 Flax 官網文件上的 Getting Started 程式,希望能夠給讀者們一個更清晰的程式架構。程式的 colab 檔可以由此下載,大家先去跑一次,再讀讀程式中的註解,應該就可以對於 Flax 有些初步的認識。
之前老頭曾以 Pytorch 來載入 MNIST 資料集,這一回 Flax 選用 TensorFlow 的 dataset 服務來載入,做為訓練的標的。目前 JAX 的生態系裏還沒有看到自行定義的 dataset 服務,應該也沒有什麼必要自己再弄一套,利用現有的就可以了。
這個範例老頭分別以 Flax 中的「完整寫法」及「精簡寫法」[36.1]設計兩個結構完全一樣的 CNN 神經網路,之後分別訓練這兩個網路作為比較。這個設計有點模仿 Keras 中的 Sequential 機制,對於簡單的模型,「精簡寫法」更為簡潔。
另外要注意的是「Flax 模型的參數是獨立於模型之外的」!原因在於必須維持模型運算的「純粹 pure」,相同的輸入,得到相同的輸出 (讀者可以參考老頭先前有關純函式的貼文)。模型本身只保留運算流程,而把訓練 (或推理)資料及參數當做輸入,那麼整個模型的計算,就能維持純函式的特性,利於 JIT 編譯以加快運算速度。
有了整體的感覺之後,接著就可以進入細節了。
[36.1] Flax 正式的名稱為「明確的 explicitly」宣告法及「行內的 in-line」宣告法。