(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載
JAX JIT 使用原則之一:要被 jax.jit() 編譯的函式,應該要是純函式 (pure function) ,否則容易出現意想不到的問題!我們先來看看什麼是純函式,再來研究會出什麼問題。
一個純函式必須符合以下的條件 [13.1]:
有那些原因會造成相同的輸入,卻有不同的輸出呢?這裏列舉一些:
副作用是在函式運算過程中,造成系統或外部世界狀態的改變。常見的副作用為:
撰寫 JAX JIT 所編譯的函式時,要儘量避免副作用,否則程式編譯後的行為,會超出我們的預期。
以下是一些不純函式 (impure function) 的例子:
# side effect, output
#=========================================================================
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# refer to global variable
#=========================================================================
g1 = 0.
def impure_uses_globals(x):
return x + g1
# side effect, change global variable
#=========================================================================
g2 = 0.
def impure_saves_global(x):
global g2
g2 = x
return x
老頭在這僅僅簡單的介紹純函式,讓大家能有基本的概念。讀者若想要更進一步研究純函式,可以參考「功能式程式設計 (functional programming)」的相關文獻。
要徹底了解不純函式會造成的問題,必須先知道 JAX JIT 編譯器的運作方式。這個部份老頭日後會介紹給大家,現階段我們可以先看看幾個例子,給大家一些初步的概念:
1. 當函式裏含有 print() 這個輸出指令時,JAX JIT 會對它執行結果造成什麼影響?
def impure_print_side_effect(x):
print("Executing function!!!!") # This is a side-effect
return x
impure_jit_fun = jax.jit(impure_print_side_effect)
print(f'1st execution: {impure_jit_fun(99)}')
print('===========================================')
print(f'2nd execution: {impure_jit_fun(99)}')
print('===========================================')
print(f'3rd execution: {impure_jit_fun(99)}')
output:
Executing function!!!!
1st execution: 99
'==========================================='
2nd execution: 99
'==========================================='
3rd execution: 99
注意!只有當函式第一次呼叫的時候,print("Executing function!!!!") 才有效果,之後第二次及第三次呼叫則看不到 print() 的訊息!為什麼?大家先把疑惑放在心裏。
2. 接下來我們來實驗一下全域變數的影響:
下面的例子是函式內參考全域變數的值。
g1 = 0.
def impure_uses_globals(x):
return x + g1
impure_jit_fun = jax.jit(impure_uses_globals)
# 1st call, at this moment g1 = 0
print(f'1st call: {impure_jit_fun(99.9)}')
print('===========================================')
# not change g1 to 1.
g1 = 1.
print(f'Currently g1 = {g1}')
print('===========================================')
# 2nd call, at this moment g1 = 0
print(f'2nd call: {impure_jit_fun(99.9)}')
output:
*1st call: 99.9000015258789
'==========================================='
Currently g1 = 1.0
'==========================================='
2nd call: 99.9000015258789
程式的第二次呼叫 impure_jit_fun() 前已經將全域變數 g1 改為 1. 了,為什麼其回傳值並沒有加 1 呢?
3. 現在來看一個去更新全域變數的例子,請注意!結果可能會讓你嚇到下巴脫臼!
g2 = 0.
def impure_saves_global(x):
global g2
g2 = x
return x
impure_jit_fun = jax.jit(impure_saves_global)
print(f'Before 1st call, g2 = {g2}')
print('===========================================')
print(f'1st call: {impure_jit_fun(1.0)}')
print('===========================================')
print(f'After 1st call, g2 = {g2}')
output:
*Before 1st call, g2 = 0.0
'==========================================='
1st call: 1.0
'==========================================='
After 1st call, g2 = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
函式被呼叫之後,g2 竟然變成一個奇怪的物件 Traced<ShapedArray….>!到底發生了什麼事?
現階段,為了避免奇奇怪怪的事情發生,我們先謹記,只將 JAX JIT 應用在純函式上。
註:
[13.1] 本文中有純函式的說明,主要是參考 [維基百科中文及英文版上面的資料](https://en.wikipedia.org/wiki/Pure_function , https://zh.wikipedia.org/wiki/%E7%BA%AF%E5%87%BD%E6%95%B0 )。
[13.2] 在某些條件限制下,純函式是可以有 I/O 的,即是所謂的 I/O 單子 (I/O monad) 概念,但這一系列的貼文並不準備討論它。