(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載 )
如果一個函式,其「具有 class 類別型態」的輸入參數,我們可以用 jit 修飾它嗎?先看一個例子:
# 定義一個 user-defined class
# ====================================================================================
class MyClass01():
def __init__(self, x=1.0, y=1.0):
self.x = x
self.y = y
# 定義 function
# ====================================================================================
@jax.jit
def my_func_01(cls: MyClass01, addition):
return cls.x + cls.y + addition
my_class = MyClass01(1.0, 2.0)
try:
my_func_01(my_class, 3.0)
except TypeError as e:
print(f"There's a TypeError")
print(e)
output :
從以上的程式片段可以知道,使用者自訂的 class 型態是不相容於 jit 的。當使用 jit 來編譯這一類的函式時,會產生 TypeError 例外。
第一個解法是使用「偏函式的靜態參數」宣告,老頭在先前的貼文中提到過這個方法:
# 將類別型態的輸入參數宣告為 static
# ====================================================================================
@partial(jax.jit, static_argnums=0)
def my_func_01(cls: MyClass01, addition):
return cls.x + cls.y + addition
my_func_01(my_class, 3.0)
output :
DeviceArray(6., dtype=float32, weak_type=True)
這種解法有一個嚴重的限制!不能隨意修改這個 class 類別的參數值!以下面這段程式來說明:
my_class = MyClass01(1.0, 2.0)
print(f'Before Modification: {my_func_01(my_class, 3.0)}')
# 修改 my_class 的內容
my_class.x = 3.0
my_class.y = 4.0
print(f'After Modification: {my_func_01(my_class, 3.0)}')
output :
Before Modification: 6.0
After Modification: 6.0
class 類別值的修改,並不會反應在第二次呼叫上!為什麼呢?
在第一次呼叫時,JAX/JIT 將第一個參數 my_class 視為常數,將其編入可執行碼裏,執行後,把這個可執行碼暫存起來。
在第二次呼叫時,JAX/JIT 並沒有辦法發覺 my_class 內容已經被修改過了,而是認為它和第一次呼叫的 my_class 是一樣的,因此 JAX 就直接執行剛才暫存的執行碼。
好,大家一定接著想問,為什麼 JAX 會認為第一次呼叫時的 my_class 和第二次呼叫時的 my_class 是一樣的?
原因在於 Python 中 class 型別的 hash 機制 ! 請看下列程式段:
my_class = MyClass01(1.0, 2.0)
print(f'hash before modification: {hash(my_class)}')
my_class.x = 3.0
my_class.y = 4.0
print(f'hash after modification: {hash(my_class)}')
output :
hash before modification: 8728844717001
hash after modification: 8728844717001
在 my_class 修改前和修改後 hash(my_class) 的值都是一樣的,Python 既定的 class 型別的 hash() 算法,並不會把 class attribute 的值列入考量。而不巧的是,JAX 就是利用 hash() 來判斷 my_class 是不是相同。
因此,完整的解法是要重新定義 my_class 計算 hash 的方法。
# 定義一個 user-defined class, 並定義其 __hash__ 和 __eq__
# ====================================================================================
class MyClass02():
def __init__(self, x=1.0, y=1.0):
self.x = x
self.y = y
def __hash__(self):
return hash((self.x, self.y))
def __eq__(self, other):
return (isinstance(other, MyClass02)) and\
(self.x, self.y) == (other.x, other.y)
@partial(jax.jit, static_argnums=0)
def my_func_02(cls: MyClass02, addition):
return cls.x + cls.y + addition
my_class02 = MyClass02(1.0, 2.0)
print(f'Before Modification: {my_func_02(my_class02, 3.0)}')
print(f' {hash(my_class02)}')
# 修改 my_class 的內容
my_class02.x = 3.0
my_class02.y = 4.0
print(f'After Modification: {my_func_02(my_class02, 3.0)}')
print(f' {hash(my_class02)}')
output :
要注意的是,當我們修改 class 中的 hash() 算法時,也要同時修正 eq() 的算法,以確保兩者的語義保持一致。這一部份老頭就不多著墨,有興趣的讀者可以去參考 Python 的 Data Model [33.1]。
參考:
[33.1] 可參考 Python Data Model: hash。