iT邦幫忙

2022 iThome 鐵人賽

DAY 10
0
AI & Data

JAX 好好玩系列 第 10

JAX 好好玩 (10) : JAX.NUMPY (6) : 超過範圍的索引

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

我們更進一步的來研究當 DeviceArray 陣列索引 (index) 超過範圍 (out-of-bounds) 時, JAX 是怎麼來處理的。

Numpy 會抛出錯誤

先看 Numpy,當程式嘗試著用超過範圍的索引來存取一個陣列時,Numpy 會產生一個執行時的「索引錯誤 (IndexError)」:

x = np.ones(shape=(5,4))
print(x[9,3])

output:

IndexError: index 9 is out of bounds for axis 0 with size 5

JAX 的既定策略

而 JAX 卻無法使用這一個策略,原因在於 JAX 會優先使用 GPU 或 TPU 來執行陣列相關運算,而在 GPU/TPU 上,非常困難 (甚至於不太可能) 偵測執行期間的程式例外 (Exception) 及錯誤。JAX 採取很有彈性的作法來處理這個問題,我們就來看看 JAX 目前的既定 (default) 策略:

  1. 如果是值設定 (update) 的動作,當索引超過範圍時,JAX 忽略這個設定。例如:
x = jnp.arange(5)
y1 = x.at[6].set(99)
y2 = x.at[2:7].set(88)

print(f'Original array: {x}')
print(f'y1            : {y1}')
print(f'y2            : {y2}')

output:
Original array: [0 1 2 3 4]
y1 : [0 1 2 3 4]
y2 : [ 0 1 88 88 88]

  1. 如果是取值的動作,當索引超過上界, JAX 回傳位於上界的元素 (element) 值。
  2. 當索引為負數時, JAX 遵循 Python 的語法原則。
x = jnp.arange(5)
y1 = x.at[6].get()  # out-of-bounds,
y2 = x.at[-2].get() # note! minus indexing follows python semantics

print(f'Original array: {x}')
print(f'y1            : {y1}')
print(f'y2            : {y2}')

output:
Original array: [0 1 2 3 4]
y1 : 4
y2 : 3

mode= 參數

要更進一步的來控制 JAX 處理索引的行為, JAX 在 DeviceArray.at[]. 的各個運算,提供了 mode= 參數,並提供以下的參數值可供選擇:

  • “promise_in_bounds”:這是既定的選項,其行為如上所述。
  • “clip”: 一率把超過範圍的索引改變,將之指向最上界。
  • “drop”:一率忽略超過範圍的索引。
  • “fill” with “fill_value=”:忽略設定值的動作,讀值時,傳回 fill_value= 的參數值,若沒指定,則回傳 nan。

舉一些例子:

x = jnp.arange(5.0)
y1 = x.at[10].add(10,mode='clip')  # index clip to "4"
y2 = x.at[10].get(mode='clip')    # index clip to "4"
y3 = x.at[10].get(mode='drop')
y4 = x.at[10].get(mode='fill', fill_value=-99.99) 
y5 = x.at[10].get(mode='fill')   

print(f'Original array: {x}')
print(f'y1            : {y1}')
print(f'y2            : {y2}')
print(f'y3            : {y3}')
print(f'y4            : {y4}')
print(f'y5            : {y5}')

output:
*Original array: [0. 1. 2. 3. 4.]
y1 : [ 0. 1. 2. 3. 14.]
y2 : 4.0
y3 : nan
y4 : -99.98999786376953
y5 : nan

JAX 官方文件強調,目前 (撰寫這段文字的時候為 2022/09/06) mode= 機制仍在實驗中,未來可能會有進一步的優化,讀者如果必須注意 JAX 不同版本之間的差別。


上一篇
JAX 好好玩 (9) : JAX.NUMPY (5) : DeviceArray 初探
下一篇
JAX 好好玩 (11) : JAX.NUMPY (7) : 其他 jax.numpy 和 Numpy 的不同點
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言