(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載)
雖然 jax.numpy 是依據 Numpy 的語法和語義來設計的,但仍有幾個不同點需要注意。第一個要介紹的是「虛擬亂數產生器 (Pseudo Random Number Generator; PRNG)」。
在 Numpy 中要產生亂數,所使用的演算法叫 MT19937,它是「梅森旋轉法 Mersenne Twister」的一個變體 [6.1]。 亂數是藉由一個「全域狀態 (global state) 」來產生的,每一個亂數產生的操作,都會導致全域狀態的改變。
在這裏大家要注意「全域」這兩個字,意指一個 Python 進程 (process) 中的所有執行緒 (thread) 皆共用同一個全域狀態,不同執行緒之間的亂數產生行為彼此相互影響,這個是不個不太好的設計。
Numpy 提供 np.random.seed() 來設定全域狀態的初始值,例如:
np.random.seed(0)
全域狀態是由 624 個 32-bit unsigned int (產生亂數用) 和數個屬性參數 [6.2] 所構成 (在此不細談),Numpy 提供 np.random.get_state() 來獲取全域狀態的內容:
prng_state = np.random.get_state()
index = 0
for element in prng_state:
print(f'Tuple {index} type: {type(element)} - \
{element.shape if isinstance(element,np.ndarray) else ""}')
index += 1
output:
Tuple 0 type: <class 'str'>
Tuple 1 type: <class 'numpy.ndarray'> - (624,)
Tuple 2 type: <class 'int'>
Tuple 3 type: <class 'int'>
Tuple 4 type: <class 'float'>
全域狀態唯一決定了下一個亂數所產生的值,從以下的程式片斷可以看出來,只要重設全域狀態至相同的初始值,其後產生的亂數都是相同的。
np.random.seed(7)
print(np.random.uniform())
print(np.random.uniform())
print("==============================")
np.random.seed(7)
print(np.random.uniform())
print(np.random.uniform())
output:
0.07630828937395717
0.7799187922401146
'=============================='
0.07630828937395717
0.7799187922401146
MT19937 有以下的缺點,致使 JAX 在設計時,決定捨棄它,採用更新的方法。
我們先用一個例子來說明 JAX PRNG 的使用方法。
from jax import random
# get the key
key = random.PRNGKey(0)
# split the key
key, subkey = random.split(key)
# use subkey
print(random.normal(subkey, shape=(2,)))
# splict the key again
key, subkey = random.split(key)
# usb subkey
print(random.normal(subkey, shape=(2,)))
output:
[ 0.19307722 -0.52678293]
[ 0.00870701 -0.04888523]
首先要注意的是,JAX 有關亂數產生的 API 是放在 jax.random 之下的,而非 jax.numpy 下:
from jax import random
使用前,要先產生一個 key ,JAX PRNG 是利用 key 來產生亂數,每一個亂數生成 API 都需要輸入 key 值。
key = random.PRNGKey(0)
key 不要直接用,要先用 jax.randdom.split() API 分割成兩個 key [6.4],其中一個 (key) 保留起來,另外一個 (subkey) 可以用來產生亂數。
key, subkey = random.split(key)
print(random.normal(subkey, shape=(2,)))
要再一次產生亂數前,要先分割上次保留的 key,保留一個,使用一個,如此生生不息。
key, subkey = random.split(key)
print(random.normal(subkey, shape=(2,)))
以下的流程圖,詳細說明了 JAX PRNG 的使用方法:
我們可以總結 JAX PRNG 不同於 Numpy 的特性:
註:
[6.1] 參考 維基百科 。
[6.2] 可以參考 numpy 的使用者手冊 。
[6.3] 有關 Big Crush 可參考 http://www.iro.umontreal.ca/~lecuyer/myftp/papers/testu01.pdf 。另外可以參考 這份報告 ,在 R 的 MT19937 實作上,會有 2 個 Big Crush 項目失敗。
[6.4] 分割成兩個 key 是一個過份簡化的說法,其實 split() 可以用參數指定分割後的數量,在後續的貼文中老頭會加以說明。