iT邦幫忙

2022 iThome 鐵人賽

DAY 11
0
AI & Data

JAX 好好玩系列 第 11

JAX 好好玩 (11) : JAX.NUMPY (7) : 其他 jax.numpy 和 Numpy 的不同點

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

JAX (jax.numpy) 在設計的時候,就儘量依循 Numpy 的語法和語義,但是為了效率及其他的考量,不得不有一些偏離。在此之前,老頭已紹介紹以下的不同之處:

  • 虛擬亂數產生器 (PRNG)
  • 陣列資料結構及類別 (Array Data Structure, ndarray vs DeviceArray)
  • 陣列資料的不可變性 (Immutibility)
  • 陣列索引超過範圍的處理 (Out-of-bonds Indexing)

現在我們來探討其他的不同點。

非陣列型資料的處理

在 Numpy 中,我們很習慣用其 API 來處理 Python 的 list 或 tuple,例如:

np.sum([1,2,3])
np.sum((1,3,5))

但在 jax.numpy,大部份的 API 皆不接受 list 及 tuple,這種狀況會導致程式抛出 TypeError 的例外 (exception):

try:
  jnp.sum([1, 2, 3])
except TypeError as e:
  print(f"TypeError: {e}")

output:
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

要解決這個問題,我們必須用 jax.numpy.array() 先將 list (or tuple) 轉換成 DeviceArray 類別,而後再進行所須的運算:

jnp.sum(jnp.array([1,2,3]))

output:
DeviceArray(6, dtype=int32)

不安全的型態轉換 (unsafe type cast)

當執行不安全的型態轉換時,JAX 的實作和其運作的平台有關,可以被視為未定義的。而其結果,通常和 Numpy 的結果不同。例如:

print(np.arange(254.0, 258.0).astype('uint8'))
print(jnp.arange(254.0, 258.0).astype('uint8'))

output:
*[254 255 0 1]
[254 255 255 255]

以老頭在 colab 上做的實驗,jax.numpy 會把 (float32) 256.0 轉成 255,而 Numpy 則轉成 0。然而如果是以下的程式段:

print(np.arange(254, 258).astype('uint8'))
print(jnp.arange(254, 258).astype('uint8'))

output:
*[254 255 0 1]
[254 255 0 1]

兩者的結果又一致了。老頭建議,讀者們在 JAX 中儘量避免不安全的型態轉換,如果一定要的話,應該要先在平台上做一些試驗,確保你的平台所生出的結果是你所能預期的。

此外,Numpy 的 ndarray.astype() API 提供了 casting= 參數選項,讓我們能精確的掌握轉換的模式,而 JAX 的 DeviceArray.astype() 目前並沒有提供類似的參數。

小結論

這一回對於 jax.numpy 的簡介都到此告一段落。老頭之所以先介紹 jax.numpy 給大家,是因為它馬上就可以用!不論你現在用的框架是什麼,它都可以協助你加速陣列資料的前處理,節省你非常多的時間。 


上一篇
JAX 好好玩 (10) : JAX.NUMPY (6) : 超過範圍的索引
下一篇
JAX 好好玩 (12) : JAX JIT (1) : 開啓執行效率之門
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言