iT邦幫忙

2022 iThome 鐵人賽

DAY 29
0
AI & Data

JAX 好好玩系列 第 29

JAX 好好玩 (29) : Pytree

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

簡介

Pytree 是 JAX 定義的資料結構,依照 JAX 官方文件 [29.1] :

a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.

  • Pytree 是一個「容器 container」,它包括 Python 裏定義好的 list / tuple / dict。
  • 這個容器內的元素,可以是另一個容器 (亦即另一個 Pytree) ,或容器外的其他資料型態 (稱之為葉節點 leaf)。
  • 如此輾轉下去,可以成為一個樹狀資料結構。

舉一些 Pytrees 的例子:

# 一層的 Pytree,內含三個葉節點。
[1, 'a', object()]
# 二層的 Pytree。
(1, (2, 3), ())
# 三層的 Pytree。
 [1, {'k1': 2, 'k2': (3, 4)}, 5]

值得注意的是,我們常用的 DeviceArray 資料結構,它是一個葉節點,而非容器。

另外,None 在 jax Pytree 裏被視為「空容器」,而非葉節點。

在 JAX 裏,Pytree 常常用來包裝 (1) 模型參數 model parameters ,(2) 資料集 dataset entries 和 (3) RL agent observations,使得它們便於管理及儲存。

常用的 Pytree API

在這裏老頭介紹幾個 Pytree 常用的 API,完整的 API 函式列表,可以參考 [29.2]。

https://ithelp.ithome.com.tw/upload/images/20221003/201296169YKSez2ezU.png

tree_leaves () 會將 tree 中的葉節點一一取出,並將它們放置在 list 中回傳。可以把它視為壓平整個 tree 。

[按:在早期的 JAX 版本中,tree_leaves 是直接放在 jax 封裝下(jax.tree_leaves),所以在網路上比較舊的範例程式中,讀者們還可以看到這樣的呼叫方式。]

import jax.tree_util as jtree
 
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]
 
# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = jtree.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

output:
https://ithelp.ithome.com.tw/upload/images/20221003/20129616tL1L6EhCjh.png

前面提到,None 在 jax Pytree 裏被視為「空容器」,而非葉節點。可以用下列的程式片斷來檢測這個說法:

jtree.tree_leaves([None, None, None])

output:
[]

is_leaf= 的用法,是在變更 Pytree 判斷樹中某一節點是不是葉節點的方式。假設我們希望暫時地將 list 或 dict 視為葉結節點時,我們就可以分別定義「判別函式」,去判斷輸入的資料型態是不是 list / dict ,如果是,就回傳 True,告訴 jax.tree_util 它是葉結點。

在呼叫 tree_leaves() 時,利用 is_leaf= 指定判別函式。

# to force a container type as leaf.
# ==============================================================================
 
# force list to be a leaf
def check_list(x):
    return isinstance(x, list)
 
# force dict to be a leaf
def check_dict(x):
    return isinstance(x, dict)
 
print(jtree.tree_leaves(example_trees[2]))
print(jtree.tree_leaves(example_trees[2], is_leaf=check_list))
print(jtree.tree_leaves(example_trees[2], is_leaf=check_dict))

output:
[1, 2, 3, 4, 5]
[[1, {'k1': 2, 'k2': (3, 4)}, 5]]
[1, {'k1': 2, 'k2': (3, 4)}, 5]

https://ithelp.ithome.com.tw/upload/images/20221003/20129616iA9X1WLBhF.png

tree_map 的運作方式,基本上和 Python map 一樣,套用函式 f 在所有的葉節點上,回傳一個結構相同,但包含新的值的 Pytree。

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]
 
jax.tree_map(lambda x: x*2, list_of_lists)

output:
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

函式 f 也可以被套用於多個 Pytree 的葉節點上,例如:

another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x-y, list_of_lists, another_list_of_lists)

output:
[[0, 0, 0], [0, 0], [0, 0, 0, 0]]

使用時要注意,函式 f 輸入參數的數量,必須和呼叫 tree_map() 時輸入的 pytree 數量一致,而且這些輸入的 Pytree 必須要有一致的結構,否則會產生執行時錯誤。

is_leaf= 的用法,可以參考前面 tree_leaves() API。

https://ithelp.ithome.com.tw/upload/images/20221003/201296165IRut1goMN.png

回傳 Pytree 的結構定義,容器以 Python 的語法來顯示,葉節點則以 * 表示。

for pytree in example_trees:
  struct = jtree.tree_structure(pytree)
  print(f"{repr(pytree):<45} : {struct}")

output:
https://ithelp.ithome.com.tw/upload/images/20221003/201296164wxfrsob4D.png

目前為止,對於 Pytree 的介紹可能還是太抽象了,稍後老頭會打造一個簡單的神經網路模型, 以實際的例子來說明 Pytree 的用法。

註:

[29.1] 參考 What is a pytree
[29.2] 可參考 jax.tree_util package


上一篇
JAX 好好玩 (28) : vmap 自動向量化
下一篇
JAX 好好玩 (30) : 綜合演練 – 簡單的 MLP
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言