前情提要: 基本上我已經把我想講的都講完了,其中包含pytorch lightning, 爬蟲, fastapi, deploy model,接下來會再看要講甚麼,讓碩班同學更好入手。
今天要介紹的是einops這套,怎麼樣利用這套來增加可讀性。
我們先回顧一下先前寫得簡單的model,如果沒有特別寫註解,可能還要想一下他的維度會變甚麼,如果model當中有更多的reshape, transpose, permute, 加上沒註解的話,那根本頭暈眼花,此時我們就可以用einops來幫助我們拉。
def forward(self, x):
'''
x: [B, C, W, H]
B: batch size
C: channel
W: Width
H: Hight
'''
batch_size = x.size(0)
x = x.view(batch_size, -1) # [B, C, W, H] -> [B, C * H * W]
x = self.model(x)
return x
pip install einops
練習的時候我蠻常做下面的事,主要就是測試兩個運算一不一樣,最後用個assert來確認,assert在很多程式會用來檢查shape對不對,或用於測試時使用,另一個是torch.testing.assert_close。
你會發現在rearrange裡面就已經把當初的註解的寫進去了,主要就是把箭頭前轉換成箭頭後,()代表這幾個合併成一個維度
from einops import rearrange
import torch
import functools
assert_equal = functools.partial(torch.testing.assert_close, rtol = 0, atol = 0)
# assert_equal(1e-9, 1e-10)
def test_rearrange(x):
batch_size = x.size(0)
t1 = x.view(batch_size, -1)
t2 = rearrange(x, 'b c h w -> b (c h w)')
print(t1.size(), t2.size())
# print(t1, t2)
assert_equal(t1, t2)
assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
if __name__ == "__main__":
x = torch.randn(2, 3, 4, 5)
test_rearrange(x)
如果把上面的assert_equal(1e-9, 1e-10)註解拿掉,就會看到以下這樣的訊息,這個可以用於你想實作一些東西,然後看跟別人的結果一不一樣,或看差多少,就可以採用這個。
from einops import rearrange
import torch
import functools
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
# assert_equal(1e-9, 1e-10)
def test_rearrange(x):
batch_size = x.size(0)
t1 = x.view(batch_size, -1)
t2 = rearrange(x, 'b c h w -> b (c h w)')
# print(t1, t2)
assert_equal(t1, t2)
assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
def test_permute(x):
# 可以多個維度交換
# permute(0, 3, 1, 2)代表以下
# 维度 0(b)保持在第 0 位
# 维度 3(w)移動到第 1 位
# 维度 1(c)移動到第 2 位
# 维度 2(h)移動到第 3 位
t1 = x.permute(0, 3, 1, 2) # [b, c, h, w] -> [b, w, c, h]
t2 = rearrange(x, 'b c h w -> b w c h')
print(t1.size(), t2.size())
assert torch.equal(t1, t2)
def test_transpose(x):
# 只能兩個維度交換
t1 = x.transpose(0, 1) # [b, c, h, w] -> [c, b, h, w]
t2 = rearrange(x, 'b c h w -> c b h w')
print(t1.size(), t2.size())
assert torch.equal(t1, t2)
if __name__ == "__main__":
x = torch.randn(2, 3, 4, 5)
test_rearrange(x)
test_permute(x)
test_transpose(x)
今天就先到這裡囉~
也可以到github上面看更多應用