(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載)
我們更進一步的來研究當 DeviceArray 陣列索引 (index) 超過範圍 (out-of-bounds) 時, JAX 是怎麼來處理的。
先看 Numpy,當程式嘗試著用超過範圍的索引來存取一個陣列時,Numpy 會產生一個執行時的「索引錯誤 (IndexError)」:
x = np.ones(shape=(5,4))
print(x[9,3])
output:
…
IndexError: index 9 is out of bounds for axis 0 with size 5
而 JAX 卻無法使用這一個策略,原因在於 JAX 會優先使用 GPU 或 TPU 來執行陣列相關運算,而在 GPU/TPU 上,非常困難 (甚至於不太可能) 偵測執行期間的程式例外 (Exception) 及錯誤。JAX 採取很有彈性的作法來處理這個問題,我們就來看看 JAX 目前的既定 (default) 策略:
x = jnp.arange(5)
y1 = x.at[6].set(99)
y2 = x.at[2:7].set(88)
print(f'Original array: {x}')
print(f'y1 : {y1}')
print(f'y2 : {y2}')
output:
Original array: [0 1 2 3 4]
y1 : [0 1 2 3 4]
y2 : [ 0 1 88 88 88]
x = jnp.arange(5)
y1 = x.at[6].get() # out-of-bounds,
y2 = x.at[-2].get() # note! minus indexing follows python semantics
print(f'Original array: {x}')
print(f'y1 : {y1}')
print(f'y2 : {y2}')
output:
Original array: [0 1 2 3 4]
y1 : 4
y2 : 3
要更進一步的來控制 JAX 處理索引的行為, JAX 在 DeviceArray.at[]. 的各個運算,提供了 mode= 參數,並提供以下的參數值可供選擇:
舉一些例子:
x = jnp.arange(5.0)
y1 = x.at[10].add(10,mode='clip') # index clip to "4"
y2 = x.at[10].get(mode='clip') # index clip to "4"
y3 = x.at[10].get(mode='drop')
y4 = x.at[10].get(mode='fill', fill_value=-99.99)
y5 = x.at[10].get(mode='fill')
print(f'Original array: {x}')
print(f'y1 : {y1}')
print(f'y2 : {y2}')
print(f'y3 : {y3}')
print(f'y4 : {y4}')
print(f'y5 : {y5}')
output:
*Original array: [0. 1. 2. 3. 4.]
y1 : [ 0. 1. 2. 3. 14.]
y2 : 4.0
y3 : nan
y4 : -99.98999786376953
y5 : nan
JAX 官方文件強調,目前 (撰寫這段文字的時候為 2022/09/06) mode= 機制仍在實驗中,未來可能會有進一步的優化,讀者如果必須注意 JAX 不同版本之間的差別。