這篇貼文,乃是針對「第二個範例程式」中的「自訂模型」的部份,加以詳細的說明。
Flax 提供兩種方式來定義「使用者自訂模型」,一為「明確的 explict」宣告法,一為「精簡的 compact 」(或可稱為『行內的 in-line』) 宣告法。不管使用那一種方式,所有自訂的模型,都必須繼承 flax.linen.Module。習慣上大家都用 import flax.linen as nn 來載入 linen 封裝 (package), 使得大多數的範例程式中,我們看到的都是 nn.Module 這種寫法。
範例程式中,自訂模型要解決的是「cifar10 分類問題」。cifar10 資料集包括了 10 種類型的 (32 x 32 x 3) RGB 圖片,模型的輸入及輸出,將符合它的圖片規格。cifar10 資料集的重要參數定義如下:
# cifar10 資料集內共有 10 種圖形, 如下:
DS_Labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
DS_ClassNumber = len(DS_Labels)
# cifar10 的圖片尺寸
DS_ImageShape = (32,32,3)
DS_ImageFlatened = 32*32*3
因此所設計的模型,其輸人維度應該是 (批次, 32, 32, 3),其輸出應該是 (批次, 10)。
範例中的第一個模型,是以明確方式定義的 MLP 模型。MLP 模型是以數個「密集連結層 (dense layer)」,在這裏使用了一個技巧,以 keyward parameter (即程式中的 features ) 在宣告類別案例 (instance) 的時候,指定模型的超參數。
model_expmlp = ExplicitMLP(features=[512,256,128,64,32,10])
這個例子指定了 6 層 dense layers,以及每層對應的神經元個數 (在 Flax 中,稱其為 feature 個數)。
接著,必須在 setup(self) 類別函式中,宣告這個模型所包含的子層 (或著子模型)。我們使用 6 個密集連結層為此模型的子層,Flax 提供了預先定義好的 nn.Dense(),我們直接使用它就可以了。
self.layers = [nn.Dense(feat) for feat in self.features]
如果不想寫得那麼有技巧,可以一層一層個別的宣告。
self.layer1 = nn.Dense(512)
self.layer2 = nn.Dense(256)
self.layer3 = nn.Dense(128)
self.layer4 = nn.Dense(64)
self.layer5 = nn.Dense(32)
self.layer6 = nn.Dense(10)
大家可以查閱 Flax 官方文件 [38.1],看看 Flax 已經預先定義了那些子層可供使用。
在 setup() 之後,必須宣告 __call__(self, inputs) 類別函式,描述模型從輸入到輸出的計算流程。
def __call__(self, inputs):
x = inputs.reshape((-1,DS_ImageFlatened))
for i, layer in enumerate(self.layers):
x = layer(x)
if i != len(self.layers) - 1:
x = nn.relu(x) # 除了最後一層外, 其他層皆輸出 relu
return x # 回傳 logits
首先利用 reshape() 把輸入的圖片攤平,才能成為 Dense 子層的輸入,而後依序呼叫 6 個 Dense 子層。除了最後一層之外,每個子層皆使用 nn.relu() 作為激活函式 (activation function)。
大家可以查閱 Flax 官方文件 [38.1],看看 Flax 已經預先定義了那些子層可供使用。
精簡的方式只要宣告 __call__(self, inputs) 這個類別函式就可以了,但是要加上 @nn.compact 修飾字。和明確的宣告法一樣,我們需要__call__(self, inputs) 裏設計模型由輸入到輸出的運算流程,直接使用定義好的子層 (或先前定義好的自訂模型類別)。以這個例子來講, 我們使用了 Flax 定義好的卷積層 (nn.Conv)、最大池化層 (nn.max_pool)、及密集連結層 (nn.Dense)等等。
模型宣告是,指定了模型結構及運算流程,但是並沒有指定輸入資料的維度,而「初始化模型」的目的,即是在於以「虛擬輸入」來指定模型的輸入資料維度,進而決定模型最終的所需的參數數量。初始化的另一個目的,是指定模型參數的初始值,因此,在初始化時,我們傳給它 PRNG key ,以生成隨機的初始值。
初始化的目的
- 決定模型參數數量
- 指定模型參數初始值
一般來說,虛擬的輸入資料要包括批次維度,以一筆資料為一個批次即可。
# 初始化時需要 key
key = jrand.PRNGKey(3)
key, subkey1 = jrand.split(key)
model_cmpcnn = CompactCNN()
# 呼叫 init(), 並將參數保留下來
# 虛擬輸入 data :
# -- 要使用批次維度,但一筆資料即可。
params_cmpcnn = model_cmpcnn.init(subkey1, jnp.ones((1,32,32,3)))
init() 傳回模型的參數,要保留下來,接下來的模型訓練及模型儲存,要用到它。
Flax 的自訂模型遵循了 JAX 的精神,參數和計算分開,保持模型計算的「純粹性 pure」。
[38.1] 可以參考 flax.linen package,