iT邦幫忙

2022 iThome 鐵人賽

DAY 7
0
AI & Data

JAX 好好玩系列 第 7

JAX 好好玩 (7) : JAX.NUMPY (3) : 再探 JAX PRNG

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

上回老頭介紹了 JAX PRNG 的基本使用方法,現在讓我們更深一層的來探討它。

JAX PRNG Key

前面提到,Numpy 的亂數產生器是依據擁有 642 個 unsigned int32 所組成的「全域狀態」來運作的,而 JAX 則是用 PRNG Key。那麼這個 key 的資料結構是什麼?

import jax
key = jax.random.PRNGKey(7)
key

output:
DeviceArray([0, 7], dtype=uint32

key 是一個 DeviceArray 類別,一維陣列,兩個元素,元素的型態為 unsigned int32。JAX key 只佔用 64 個 bits,相較於 Numpy 全域狀態的 2.5Kb,小得太多了。當然,由於 key 在使用時必須將以分割,而且對於程式中不同的執行緒及不同的功能模塊,JAX 都建議用 random.PRNGKey() 來產生不同的 key,所以程式在實際運行時,key 所需的空間是 64 x n bits (n 為同時存在的 key 數量),雖是如此,大概也遠小於 2.5Kb。

DeviceArray 類別,它是 JAX 自行定義的陣列類別,定義在 jax.numpy.DeviceArray,它的角色等同於 Numpy 中的 ndarray,程式中 JAX 要處理和產生的資料,皆以 DeviceArray 格式來存放。 未來我們會更詳細的介紹 DeviceArray。

# 用 isinstance() 來檢查 key 的類別
import jax
key = jax.random.PRNGKey(1)
isinstance(key, jax.numpy.DeviceArray)

output:
True

更有彈性的 key 分割

在呼叫 jax.random.split() 時,其實我們可以利用參數「num=」 來指定 key 分割的數量,習慣上,我們保留分割後的第一個 key 而用其他的 key 來產生亂數。

import jax
key = jax.random.PRNGKey(3)
key, *subkeys = jax.random.split(key, num=5)

subkeys 是一個含有 4 個 key 的 list,可以分別使用。

print(type(subkeys))
print(f'{[type(element) for element in subkeys]}')

output:
<class 'list'>
[<class 'jaxlib.xla_extension.DeviceArray'>, <class 'jaxlib.xla_extension.DeviceArray'>, <class 'jaxlib.xla_extension.DeviceArray'>, <class 'jaxlib.xla_extension.DeviceArray'>]

在 JAX 官方文件中有一個例子,它試著用一個 key 呼叫 API 產生 3 個 (shape = (3,)) 標準常態分布亂數,然後用同一個 key ,利用 split() 產生 3 個 key,並用這 3 個 key 分別各產生 1 個常態分布亂數,結果前者和後者所產生出來的亂數集合是不同的。

import jax
import jax.numpy as jnp
import numpy as np
from jax import random

key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.PRNGKey(42)
print("all at once: ", random.normal(key, shape=(3,)))

output:
individually: [-0.04838832 0.10796154 -1.2226542 ]
all at once: [ 0.18693547 -1.2806505 -1.5593132 ]

很有趣的實驗,可以讓我們對 JAX key 的使用更有感覺 !


上一篇
JAX 好好玩 (6) : JAX.NUMPY (2) : 虛擬亂數產生器
下一篇
JAX 好好玩 (8) : JAX.NUMPY (4) : 用了才知道它的快
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言