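Recap: yesterday I covered the basic operations. There are more ways to write them, but rearrange is the one I use most.
Today's example is one that left me completely lost when I first read it. It comes from an architecture that is fairly common in speech enhancement and separation. We won't discuss the architecture itself; we'll only look at the readability of the code.
Code excerpt (https://github.com/JusperLee/Dual-Path-RNN-Pytorch/blob/master/model/model_rnn.py)
In that file you'll find the following snippet, which performs the intra RNN computation. Back then I wasn't very comfortable with tensor dimensions, so this code confused me: there are comments, but there are a lot of dimension manipulations.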
def forward(self, x):
    '''
       x: [B, N, K, S]
       out: [Spks, B, N, K, S]
    '''
    B, N, K, S = x.shape
    # intra RNN
    # [BS, K, N]
    intra_rnn = x.permute(0, 3, 2, 1).contiguous().view(B*S, K, N)
    # [BS, K, H]
    intra_rnn, _ = self.intra_rnn(intra_rnn)
    # [BS, K, N]
    intra_rnn = self.intra_linear(intra_rnn.contiguous().view(B*S*K, -1)).view(B*S, K, -1)
    # [B, S, K, N]
    intra_rnn = intra_rnn.view(B, S, K, N)
    # [B, N, K, S]
    intra_rnn = intra_rnn.permute(0, 3, 2, 1).contiguous()
    intra_rnn = self.intra_norm(intra_rnn)
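Part of what makes the snippet hard to read is the repeated permute / contiguous / view pattern. As a small sketch (my own illustration, not taken from the repository), this shows why `.contiguous()` has to sit between `permute` and `view`: `permute` only changes strides, and `view` cannot merge the `B` and `S` axes until the data is actually laid out in the new order.

```python
import torch

x = torch.randn(2, 3, 4, 5)            # [B, N, K, S]
B, N, K, S = x.shape

# permute returns a new view of the same storage with reordered strides
permuted = x.permute(0, 3, 2, 1)       # [B, S, K, N]
is_contig = permuted.is_contiguous()

# view() on the permuted tensor cannot merge B and S, so it raises
try:
    permuted.view(B * S, K, N)
    view_failed = False
except RuntimeError:
    view_failed = True

# copying into contiguous memory first makes the merge possible
merged = permuted.contiguous().view(B * S, K, N)   # [BS, K, N]

# element (b, n, k, s) of x ends up at (b * S + s, k, n)
b, n, k, s = 1, 1, 3, 2
same_value = bool(merged[b * S + s, k, n] == x[b, n, k, s])
print(is_contig, view_failed, same_value)
```

With rearrange, this bookkeeping is folded into the pattern string, so the intermediate layout never has to be spelled out by hand.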
Here I map each small piece of the code above to its own p function and test it, ignoring the computation done by self.intra_rnn and self.intra_linear.
from einops import rearrange
import torch

x = torch.randn(2, 3, 4, 5)
B, N, K, S = x.shape

def p1(x):
    t1 = x.permute(0, 3, 2, 1).contiguous().view(B * S, K, N)
    t2 = rearrange(x, 'b n k s -> (b s) k n')
    print(f'p1: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1

def p2(x):
    t1 = x.contiguous().view(B * S * K, -1)
    t2 = rearrange(x, 'bs k n -> (bs k) n')
    print(f'p2: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1

def p3(x):
    t1 = x.view(B*S, K, -1)
    # One axis is being split into two here, so the size of either bs or k must be given
    t2 = rearrange(x, '(bs k) n -> bs k n', k = K)
    print(f'p3: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1

def p4(x):
    t1 = x.view(B, S, K, N).permute(0, 3, 2, 1).contiguous()
    t2 = rearrange(x, '(b s) k n -> b n k s', b = B)
    print(f'p4: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1
The original code can then be rewritten as below, which should look a lot more compact.
def all(self, x):
    B, N, K, S = x.shape
    intra_rnn = rearrange(x, 'b n k s -> (b s) k n')
    intra_rnn, _ = self.intra_rnn(intra_rnn)
    intra_rnn = self.intra_linear(rearrange(intra_rnn, 'bs k n -> (bs k) n'))
    # This single line can replace the two commented-out lines below
    intra_rnn = rearrange(intra_rnn, '(b s k) n -> b n k s', b = B, k = K)
    # intra_rnn = rearrange(intra_rnn, '(bs k) n -> bs k n', k = K)
    # intra_rnn = rearrange(intra_rnn, '(b s) k n -> b n k s', b = B)
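Today I used code I had actually read before as the example, and you can see how much more compact it becomes. From here it just takes practice; the next time you write a model, try expressing the shape operations this way.
That's all for today~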