(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
之前提到過的 Google Brain 2018 年的論文 Compiling machine learning programs via high-level tracing 裏,開宗明義就說:
… JAX 就是一個針對特定領域,使用 tracing 和 JIT 方式的編譯器,以產生可執行於 GPU、TPU 等硬體加速器上的機器碼...
(… JAX, a domain-specific tracing JIT compiler for generating high-performance accelerator code …)
接下來,老頭想和大家分享我研讀 JIT 的一些心得。
JIT 是 Just-In-Time 的頭字語 (Acronym),表示「即時」、「適時」的意思。JIT 強調在需要的當下提供服務,不早也不晚。以餐飲為例,自助餐就不是一個 JIT 的服務,菜早就準備好了,放在枱上等客人取用;而熱炒店比較像是 JIT,要吃的時候現炒給你。JIT 的編譯器也是如此,它不像是 C/C++ 編譯器,在程式撰寫完成後馬上編譯成執行檔,放在硬碟中等用戶執行;它是在程式執行的時候才執行編譯的程序,編譯後立馬執行。使用 JIT,我們可以在「直譯式語言的彈性及便利」和「編譯式語言的效率及精確」間取得平衡。
我們來體會一下 JIT 可以帶來的效率提升:
#標準的 import
import jax
import jax.numpy as jnp
import numpy as np
以 Numpy 做為比較基準
size = 1000000
np.random.seed(0)
def selu_np(x, alpha=1.67, lmda=1.05):
return lmda * np.where(x>0, x, alpha*np.exp(x)-alpha)
x = np.random.normal(size=(size,))
%timeit selu_np(x)
output:
*41.7 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
將 Numpy 換成 jax.numpy,觀察效率的提升
size = 1000000
key = jax.random.PRNGKey(0)
def selu_jnp(x, alpha=1.67, lmda=1.05):
return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)
x = jax.random.normal(key, (size,))
%timeit selu_jnp(x).block_until_ready()
output:
*1.27 ms ± 71.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
jax.numpy 再加上 JIT ,看看最終得到什麼樣的成果!第一種寫法:使用 jax.jit() API
size = 1000000
key = jax.random.PRNGKey(0)
def selu_jnp(x, alpha=1.67, lmda=1.05):
return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)
selu_jnp_jit = jax.jit(selu_jnp) # make function "selu_jnp" a jit function.
x = jax.random.normal(key, (size,))
%timeit selu_jnp_jit(x).block_until_ready()
output:
*136 µs ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
第二種寫法:@jax.jit 函式裝飾字
size = 1000000
key = jax.random.PRNGKey(0)
@jax.jit
def selu_jnp(x, alpha=1.67, lmda=1.05):
return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)
x = jax.random.normal(key, (size,))
%timeit selu_jnp(x).block_until_ready()
output:
*135 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(註:JAX JIT 的兩種語法,在執行上不會有效率上的差別,大家可依狀況,選擇適合的語法)
41.7 ms 與 136 µs,超過 300 倍的效率提升,這是相當驚人的。
不知道從上面程式片斷大家有沒有發覺,我們用 JIT 加速的不是整個程式,而是一個函式 (function),請大家先了解這一點。JAX 使用的是針對特定領域 (domain-specific) 即時編譯器,並不是泛用 (general-purpose) 的編譯器,必有其使用上的限制。未來老頭會詳談 JAX 的 JIT 背後運作的原理,到時,大家就會知道為什麼會有這樣的限制。