(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載 )
另外一個讓 class 型別和 jit (以及 JAX 其他 API) 相容的方法,是註冊此一 class 為 pytree 容器 (container)。JAX 既定的 pytree 容器有 list, tuple 和 dict,使用者自訂的 class 會被視為葉節點 (leaf node) ,要讓一個自訂 class 成為一個 pytree 容器,必須要:
由 pytree 對於 list, tuple 及 dict 攤平和重組的操作,我們可以學到如何在 class 中實作這兩個方法。pytree 提供的 API 分別是 tree_flatten() 和 tree_unflatten(),tree_flatten() 傳入一個 pytree 變數,傳回兩個值,第一個是攤平的葉節點 list,第二個是樹結構定義 (PyTreeDef) ;tree_unflatten() 傳入樹結構定義和攤平的葉節點 list,傳回一個 pytree。從以下的程式段可以知道它們的用法:
# 宣告一個 pytree 變數
tree_01 = [1.,(2., 3.),{'age':55, 'name':'edward'}]
# 用 tree_flatten() 來取出葉節點及 pytree 的結構
flat_leaves_01, flat_struct_01 = jtree.tree_flatten(tree_01)
print(type(flat_leaves_01), type(flat_struct_01))
print("=================================================")
print(flat_leaves_01)
print("=================================================")
print(flat_struct_01)
output :
# 用 tree_unflatten() 來重組 pytree
tree_unflat_01 = jtree.tree_unflatten(flat_struct_01, flat_leaves_01)
print(tree_unflat_01)
output :
[1.0, (2.0, 3.0), {'age': 55, 'name': 'edward'}]
要讓 class 成為一個 pytree 容器,首先得在 class 內實作符合這個 class 的 tree_flatten() 和 tree_unflatten() 方法,然後,再呼叫 register_pytree_node() 來註冊這個類別及其特有的方法就可以了:
# 定義一個 user-defined class, 並定義其 _tree_flatten, _tree_unflatten()
# ====================================================================================
class MyClass03():
def __init__(self, x=1.0, y=1.0):
self.x = x
self.y = y
def _tree_flatten(self):
children = (self.x,self.y) # arrays / dynamic values
aux_data = None # aux_data 要傳回重組此 class 的資訊,以這個例子來說
# 傳回 children 就夠了, 因此, 設為 None 即可
return (children, aux_data)
@classmethod # 注意! 這個修飾是必要的
def _tree_unflatten(cls, aux_data, children):
return cls(*children) # 呼叫 class 的 __init__ 來建構新的 class 案例. 此例中
# 不需要 aux_data.
# 不必再宣告為 partial
@jax.jit
def my_func_03(cls: MyClass03, addition):
return cls.x + cls.y + addition
# register MyClass03
jtree.register_pytree_node(MyClass03,
MyClass03._tree_flatten,
MyClass03._tree_unflatten)
my_class03 = MyClass03(1.0, 2.0)
print(f'Before Modification: {my_func_03(my_class03, 3.0)}')
print(f' {hash(my_class03)}')
# 修改 my_class 的內容
my_class03.x = 3.0
my_class03.y = 4.0
print(f'After Modification: {my_func_03(my_class03, 3.0)}')
print(f' {hash(my_class03)}')
output :
在程式中有相關的說明,但另外要注意的,my_class03 在修改前後其 hash() 值保持不變,因為我們並沒有實作 hash,但此並不會影響 jax tracing,只要將 MyClass03 註冊為 pytree 即可。