iT邦幫忙

2024 iThome 鐵人賽

DAY 17
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 17

[Day 17] einops增加model可讀性

  • 分享至 

  • xImage
  •  

前情提要: 基本上我已經把我想講的都講完了,其中包含pytorch lightning, 爬蟲, fastapi, deploy model,接下來會再看要講甚麼,讓碩班同學更好入手。

1. einops ( https://github.com/arogozhnikov/einops )

今天要介紹的是einops這套,怎麼樣利用這套來增加可讀性。
我們先回顧一下先前寫得簡單的model,如果沒有特別寫註解,可能還要想一下他的維度會變甚麼,如果model當中有更多的reshape, transpose, permute, 加上沒註解的話,那根本頭暈眼花,此時我們就可以用einops來幫助我們拉。

    def forward(self, x):
        '''
            x: [B, C, W, H]

            B: batch size
            C: channel
            W: Width
            H: Hight
        '''
        batch_size = x.size(0)
        x = x.view(batch_size, -1) # [B, C, W, H] -> [B, C * H * W]
        x = self.model(x)
        
        return x

2. einops 安裝

pip install einops

3. 基礎操作 - 1

練習的時候我蠻常做下面的事,主要就是測試兩個運算一不一樣,最後用個assert來確認,assert在很多程式會用來檢查shape對不對,或用於測試時使用,另一個是torch.testing.assert_close。

你會發現在rearrange裡面就已經把當初的註解的寫進去了,主要就是把箭頭前轉換成箭頭後,()代表這幾個合併成一個維度

from einops import rearrange
import torch

import functools
assert_equal = functools.partial(torch.testing.assert_close, rtol = 0, atol = 0)
# assert_equal(1e-9, 1e-10)

def test_rearrange(x):
    batch_size = x.size(0)
    t1 = x.view(batch_size, -1)
    t2 = rearrange(x, 'b c h w -> b (c h w)')
    print(t1.size(), t2.size())
    # print(t1, t2)
    assert_equal(t1, t2)
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."

if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 5)
    test_rearrange(x)

https://ithelp.ithome.com.tw/upload/images/20240821/20168446QxYegV8Fbc.png

如果把上面的assert_equal(1e-9, 1e-10)註解拿掉,就會看到以下這樣的訊息,這個可以用於你想實作一些東西,然後看跟別人的結果一不一樣,或看差多少,就可以採用這個。
https://ithelp.ithome.com.tw/upload/images/20240821/20168446aXg760pZys.png

4. 基礎操作 - 2

from einops import rearrange
import torch

import functools
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
# assert_equal(1e-9, 1e-10)

def test_rearrange(x):
    batch_size = x.size(0)
    t1 = x.view(batch_size, -1)
    t2 = rearrange(x, 'b c h w -> b (c h w)')

    # print(t1, t2)
    assert_equal(t1, t2)
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."

def test_permute(x):
    # 可以多個維度交換
    # permute(0, 3, 1, 2)代表以下
    # 维度 0(b)保持在第 0 位
    # 维度 3(w)移動到第 1 位
    # 维度 1(c)移動到第 2 位
    # 维度 2(h)移動到第 3 位
    t1 = x.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h]
    t2 = rearrange(x, 'b c h w -> b w c h')

    print(t1.size(), t2.size())
    assert torch.equal(t1, t2)

def test_transpose(x):
    # 只能兩個維度交換
    t1 = x.transpose(0, 1) # [b, c, h, w] -> [c, b, h, w]
    t2 = rearrange(x, 'b c h w -> c b h w')

    print(t1.size(), t2.size())
    assert torch.equal(t1, t2)

if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 5)
    test_rearrange(x)
    test_permute(x)
    test_transpose(x)

今天就先到這裡囉~
也可以到github上面看更多應用


上一篇
[Day 16] Dockerfile
下一篇
[Day 18] einops增加model可讀性 - 實際舉例
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言