(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
我們現在來探討另外一個 JAX JIT 使用上的限制 [14.1]:函式運算的過程中,所有的陣列都必須是靜態的維度 ( it requires all arrays to have static shapes.)。這是什麼意思?
先看一個符合標準的例子:
def norm_tedious(X):
mean = X.mean(0)
std = X.std(0)
y = (X - mean) / std
return y, mean, std
input = jnp.array([1,2,3,4])
print(f'input shape : {input.shape}')
y, mean, std = norm_tedious(input)
print(f'y shape : {y.shape}')
print(f'mean shape : {mean.shape}')
print(f'std shape : {std.shape}')
output:
*input shape : (4,)
y shape : (4,)
mean shape : ()
std shape : ()
當輸入參數 X 的維度固定之後,這個函式內所有的陣列運算維度都是固定的,這個就是靜態的維度。注意!它並不是要求輸入參數 X 的維度要固定,X 可以是任何合理的維度;它是說當 X 在某一個維度時,函式內的陣列維度都是固定的。
上例中,老頭故意回傳函式 norm_tedious() 內的運算過程中的變數,並印出它們的維度。當輸入參數 input 的維度是 (4,) 時,輸出值 y 的維度必然是 (4,),而中間運算結果 mean 和 std 必然是純量。
下面的例子就不符合標準:
def get_negatives(X):
y = X[X<0]
return y
get_negatives_jit = jax.jit(get_negatives)
input1 = jnp.array([1,2,3,4])
input2 = jnp.array([-1,2,-3,4])
output1 = get_negatives(input1)
output2 = get_negatives(input2)
print(f'input1 shape : {input1.shape}')
print(f'input2 shape : {input2.shape}')
print(f'output1 shape : {output1.shape}')
print(f'output2 shape : {output2.shape}')
output:
*input1 shape : (4,)
input2 shape : (4,)
output1 shape : (0,)
output2 shape : (2,)
函式 get_negatives() 回傳值的維度,不僅僅是由輸入參數 X 的維度決定,還要依據輸入參數的內含值。上例,input1 和 input2 是相同維度但含有不同的值,它們所對應的 output1 和 output2 則不同。如此,就違反了 JAX JIT 的要求。
如果硬要將這個不合適的函式用 JAX JIT 執行,執行時將會報錯:
output = get_negatives_jit(input1)
output:
'---------------------------------------------------------------------------'
*UnfilteredStackTrace Traceback (most recent call last)
in
2
----> 3 output = get_negatives_jit(input1)
…..
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4])
讀者可能會納悶,那麼多的使用限制,我們為什麼還要用 JAX JIT 呢?老頭的看法是效率!它快到我們捨不得放下它。目前 JAX 仍在發展中,截至目前 (2022/09/08) ,它的版本是 0.3.17 (2022/09/01 釋出) [14.2],仍舊是測試版。我們期待未來,JAX 社群能夠發展出更好的工具,讓程式設計師能夠檢查其函式是否符合 JIT 的要求,提出警告及建議,而有效避免程式執行時的意外狀況及錯誤的結果。
[14.1] 本文主要參考了 JAX 文件官網 To JIT or not to JIT 一文的內容。
[14.2] 可由 pypi 的網站查到最新釋出的版本。https://pypi.org/project/jax/