iT邦幫忙

2022 iThome 鐵人賽

DAY 13
0
AI & Data

JAX 好好玩系列 第 13

JAX 好好玩 (13) : JAX JIT (2) : 純函式 (Pure Function)

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

JAX JIT 使用原則之一:要被 jax.jit() 編譯的函式,應該要是純函式 (pure function) ,否則容易出現意想不到的問題!我們先來看看什麼是純函式,再來研究會出什麼問題。

什麼是純函式?

一個純函式必須符合以下的條件 [13.1]:

  • 對於相同的輸入值,這個函式必須產生相同的輸出值。
  • 這個函式不能有副作用 (side effects),而影響到程式其他部份的執行結果。
  • 基本上,函式內不能有 I/O 動作 [13.2]。

有那些原因會造成相同的輸入,卻有不同的輸出呢?這裏列舉一些:

  • 函式的輸出,參考了全域變數 (global variable)。
  • 函式的輸出,參考了會改變的靜態變數 (static variable)。
  • 函式中有輸入 (input) 指令,而函式的輸出,參考了當時的輸入值。

副作用是在函式運算過程中,造成系統或外部世界狀態的改變。常見的副作用為:

  • 更改檔案
  • 更改全域變數 (或靜態變數)
  • 輸出資料
  • 觸發外部動作 (如發送一個 http 請求)

撰寫 JAX JIT 所編譯的函式時,要儘量避免副作用,否則程式編譯後的行為,會超出我們的預期。

以下是一些不純函式 (impure function) 的例子:

# side effect, output 
#=========================================================================
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect 
  return x
# refer to global variable
#=========================================================================
g1 = 0.
def impure_uses_globals(x):
  return x + g1
# side effect, change global variable
#=========================================================================
g2 = 0.
def impure_saves_global(x):
  global g2
  g2 = x
  return x

老頭在這僅僅簡單的介紹純函式,讓大家能有基本的概念。讀者若想要更進一步研究純函式,可以參考「功能式程式設計 (functional programming)」的相關文獻。

不純函式會造成什麼問題?

要徹底了解不純函式會造成的問題,必須先知道 JAX JIT 編譯器的運作方式。這個部份老頭日後會介紹給大家,現階段我們可以先看看幾個例子,給大家一些初步的概念:

1. 當函式裏含有 print() 這個輸出指令時,JAX JIT 會對它執行結果造成什麼影響?

def impure_print_side_effect(x):
  print("Executing function!!!!")  # This is a side-effect 
  return x

impure_jit_fun = jax.jit(impure_print_side_effect)

print(f'1st execution: {impure_jit_fun(99)}')
print('===========================================')
print(f'2nd execution: {impure_jit_fun(99)}')
print('===========================================')
print(f'3rd execution: {impure_jit_fun(99)}')

output:
Executing function!!!!
1st execution: 99
'==========================================='
2nd execution: 99
'==========================================='
3rd execution: 99

注意!只有當函式第一次呼叫的時候,print("Executing function!!!!") 才有效果,之後第二次及第三次呼叫則看不到 print() 的訊息!為什麼?大家先把疑惑放在心裏。

2. 接下來我們來實驗一下全域變數的影響:

下面的例子是函式內參考全域變數的值。

g1 = 0.
def impure_uses_globals(x):
  return x + g1

impure_jit_fun = jax.jit(impure_uses_globals)

# 1st call, at this moment g1 = 0
print(f'1st call: {impure_jit_fun(99.9)}')
print('===========================================')

# not change g1 to 1.
g1 = 1.
print(f'Currently g1 = {g1}')
print('===========================================')

# 2nd call, at this moment g1 = 0
print(f'2nd call: {impure_jit_fun(99.9)}')

output:
*1st call: 99.9000015258789
'==========================================='
Currently g1 = 1.0
'==========================================='
2nd call: 99.9000015258789

程式的第二次呼叫 impure_jit_fun() 前已經將全域變數 g1 改為 1. 了,為什麼其回傳值並沒有加 1 呢?

3. 現在來看一個去更新全域變數的例子,請注意!結果可能會讓你嚇到下巴脫臼!

g2 = 0.
def impure_saves_global(x):
  global g2
  g2 = x
  return x

impure_jit_fun = jax.jit(impure_saves_global)

print(f'Before 1st call, g2 = {g2}')
print('===========================================')
print(f'1st call: {impure_jit_fun(1.0)}')
print('===========================================')
print(f'After 1st call, g2 = {g2}')

output:
*Before 1st call, g2 = 0.0
'==========================================='
1st call: 1.0
'==========================================='
After 1st call, g2 = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

函式被呼叫之後,g2 竟然變成一個奇怪的物件 Traced<ShapedArray….>!到底發生了什麼事?

現階段,為了避免奇奇怪怪的事情發生,我們先謹記,只將 JAX JIT 應用在純函式上。

註:

[13.1] 本文中有純函式的說明,主要是參考 [維基百科中文及英文版上面的資料](https://en.wikipedia.org/wiki/Pure_function , https://zh.wikipedia.org/wiki/%E7%BA%AF%E5%87%BD%E6%95%B0 )。

[13.2] 在某些條件限制下,純函式是可以有 I/O 的,即是所謂的 I/O 單子 (I/O monad) 概念,但這一系列的貼文並不準備討論它。


上一篇
JAX 好好玩 (12) : JAX JIT (1) : 開啓執行效率之門
下一篇
JAX 好好玩 (14) : JAX JIT (3) : 函式內陣列的維度
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言