JAX 官方文件中有一個很好的範例 (Training a Simple Neural Network, with PyTorch Data Loading) ,老頭將其稍稍改寫了一下,並加上一些註解,放在這裏,讀者們可以下載來執行看看。
這個例子以 Pytorch 來載入 MNIST 資料集,並將其轉換為 DeviceArray 格式。JAX 並沒有定義處理 dataset 相關的 API,當然也沒有像 Pytorch 和 TensorFlow 一般包裝好了一些著名的 dataset 供人使用。因此,有關於資料集的處理,JAX 程式設計師必須仰賴既有的 AI 框架所提供的服務。總之,只要這些資料最終能轉為 Numpy 陣列格式, JAX 就可以使用它。
這個例子也介紹了如何使用 vmap 自動向量化的功能。它的重點是:
# 定義模型預測函式 : 單一 image 預測
# ==================================================================
def predict(params, image):
activations = image
for layer in params[:-1]:
w = layer['w']
b = layer['b']
outputs = jnp.dot(activations,w) + b
activations = relu(outputs)
final_w, final_b = (params[-1]['w'],params[-1]['b'])
logits = jnp.dot(activations,final_w) + final_b
return jnp.exp(logits) / jnp.sum(jnp.exp(logits)) # return softmax
# 自動批次的預測函式
# ==================================================================
# predict(params, image)
# params: 不做 auto vectorization (對應 in_axes 的 None)
# image: 對第一維度做 auto vectorization (對應 in_axes 的 0)
batched_predict = jax.vmap(predict, in_axes=(None, 0))
# 損失函式
# ==================================================================
# MSE (Mean Squared Error)
def loss(params, images, targets):
preds = batched_predict(params, images) # 要參考自動向量化的預測函式版本
return jnp.mean((targets-preds)**2)
if CFG_UnitTest:
random_flattened_images = jrand.normal(jrand.PRNGKey(1), (2, 28 * 28))
random_targets = jnp.array([[1.,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,1.]], dtype=jnp.float32)
print(loss(HP_Params, random_flattened_images, random_targets))
另外,老頭在這個範例程式中,加入了我慣用的「單元測試 unit test 」的手法。在重要的函式及程式片斷後,用一個 if 控制結構 (參考上面的程式片斷 if CFG_UnitTest: 部份) 來包裝單元測試程式段。在程式開發的過程中,我會將 CFG_UnitTest 設為 True,每寫完一段程式馬上做測試。等到開發完成,再將 CFG_UnitTest 設為 False ,程式執行時就會自動略過單元測試。這個部份也提出來供大家參考。
OK, 大家可以直接去跑老頭提供的 colab 筆記本,Good Luck!