(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載)
JAX (jax.numpy) 在設計的時候,就儘量依循 Numpy 的語法和語義,但是為了效率及其他的考量,不得不有一些偏離。在此之前,老頭已紹介紹以下的不同之處:
現在我們來探討其他的不同點。
在 Numpy 中,我們很習慣用其 API 來處理 Python 的 list 或 tuple,例如:
np.sum([1,2,3])
np.sum((1,3,5))
但在 jax.numpy,大部份的 API 皆不接受 list 及 tuple,這種狀況會導致程式抛出 TypeError 的例外 (exception):
try:
jnp.sum([1, 2, 3])
except TypeError as e:
print(f"TypeError: {e}")
output:
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
要解決這個問題,我們必須用 jax.numpy.array() 先將 list (or tuple) 轉換成 DeviceArray 類別,而後再進行所須的運算:
jnp.sum(jnp.array([1,2,3]))
output:
DeviceArray(6, dtype=int32)
當執行不安全的型態轉換時,JAX 的實作和其運作的平台有關,可以被視為未定義的。而其結果,通常和 Numpy 的結果不同。例如:
print(np.arange(254.0, 258.0).astype('uint8'))
print(jnp.arange(254.0, 258.0).astype('uint8'))
output:
*[254 255 0 1]
[254 255 255 255]
以老頭在 colab 上做的實驗,jax.numpy 會把 (float32) 256.0 轉成 255,而 Numpy 則轉成 0。然而如果是以下的程式段:
print(np.arange(254, 258).astype('uint8'))
print(jnp.arange(254, 258).astype('uint8'))
output:
*[254 255 0 1]
[254 255 0 1]
兩者的結果又一致了。老頭建議,讀者們在 JAX 中儘量避免不安全的型態轉換,如果一定要的話,應該要先在平台上做一些試驗,確保你的平台所生出的結果是你所能預期的。
此外,Numpy 的 ndarray.astype() API 提供了 casting= 參數選項,讓我們能精確的掌握轉換的模式,而 JAX 的 DeviceArray.astype() 目前並沒有提供類似的參數。
這一回對於 jax.numpy 的簡介都到此告一段落。老頭之所以先介紹 jax.numpy 給大家,是因為它馬上就可以用!不論你現在用的框架是什麼,它都可以協助你加速陣列資料的前處理,節省你非常多的時間。