iT邦幫忙

2022 iThome 鐵人賽

DAY 12
0
AI & Data

JAX 好好玩系列 第 12

JAX 好好玩 (12) : JAX JIT (1) : 開啓執行效率之門

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

之前提到過的 Google Brain 2018 年的論文 Compiling machine learning programs via high-level tracing 裏,開宗明義就說:

… JAX 就是一個針對特定領域,使用 tracing 和 JIT 方式的編譯器,以產生可執行於 GPU、TPU 等硬體加速器上的機器碼...
(… JAX, a domain-specific tracing JIT compiler for generating high-performance accelerator code …)

接下來,老頭想和大家分享我研讀 JIT 的一些心得。

JIT 是 Just-In-Time 的頭字語 (Acronym),表示「即時」、「適時」的意思。JIT 強調在需要的當下提供服務,不早也不晚。以餐飲為例,自助餐就不是一個 JIT 的服務,菜早就準備好了,放在枱上等客人取用;而熱炒店比較像是 JIT,要吃的時候現炒給你。JIT 的編譯器也是如此,它不像是 C/C++ 編譯器,在程式撰寫完成後馬上編譯成執行檔,放在硬碟中等用戶執行;它是在程式執行的時候才執行編譯的程序,編譯後立馬執行。使用 JIT,我們可以在「直譯式語言的彈性及便利」和「編譯式語言的效率及精確」間取得平衡。

我們來體會一下 JIT 可以帶來的效率提升:

#標準的 import
import jax
import jax.numpy as jnp
import numpy as np

以 Numpy 做為比較基準

size = 1000000
np.random.seed(0)

def selu_np(x, alpha=1.67, lmda=1.05):
    return lmda * np.where(x>0, x, alpha*np.exp(x)-alpha)

x = np.random.normal(size=(size,))
%timeit selu_np(x)

output:
*41.7 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

將 Numpy 換成 jax.numpy,觀察效率的提升

size = 1000000
key = jax.random.PRNGKey(0)

def selu_jnp(x, alpha=1.67, lmda=1.05):
    return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)

x = jax.random.normal(key, (size,))
%timeit selu_jnp(x).block_until_ready()

output:
*1.27 ms ± 71.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

jax.numpy 再加上 JIT ,看看最終得到什麼樣的成果!第一種寫法:使用 jax.jit() API

size = 1000000
key = jax.random.PRNGKey(0)

def selu_jnp(x, alpha=1.67, lmda=1.05):
    return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)

selu_jnp_jit = jax.jit(selu_jnp)  # make function "selu_jnp" a jit function.

x = jax.random.normal(key, (size,))

%timeit selu_jnp_jit(x).block_until_ready()

output:
*136 µs ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

第二種寫法:@jax.jit 函式裝飾字

size = 1000000
key = jax.random.PRNGKey(0)

@jax.jit 
def selu_jnp(x, alpha=1.67, lmda=1.05):
    return lmda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)

x = jax.random.normal(key, (size,))

%timeit selu_jnp(x).block_until_ready()

output:
*135 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(註:JAX JIT 的兩種語法,在執行上不會有效率上的差別,大家可依狀況,選擇適合的語法)

41.7 ms 與 136 µs,超過 300 倍的效率提升,這是相當驚人的。

不知道從上面程式片斷大家有沒有發覺,我們用 JIT 加速的不是整個程式,而是一個函式 (function),請大家先了解這一點。JAX 使用的是針對特定領域 (domain-specific) 即時編譯器,並不是泛用 (general-purpose) 的編譯器,必有其使用上的限制。未來老頭會詳談 JAX 的 JIT 背後運作的原理,到時,大家就會知道為什麼會有這樣的限制。


上一篇
JAX 好好玩 (11) : JAX.NUMPY (7) : 其他 jax.numpy 和 Numpy 的不同點
下一篇
JAX 好好玩 (13) : JAX JIT (2) : 純函式 (Pure Function)
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言