iT邦幫忙

2024 iThome 鐵人賽

DAY 18
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 18

[Day 18] einops增加model可讀性 - 實際舉例

  • 分享至 

  • xImage
  •  

前情提要: 昨天把一些基礎的操作都講了,其實還有更多寫法,只是我自己最常用rearrange。

今天舉個當初我看完覺得一頭霧水的例子,是在語音增強和分離蠻常用到的架構,架構我們不談,我們只針對程式的可讀性做討論。

程式擷取(https://github.com/JusperLee/Dual-Path-RNN-Pytorch/blob/master/model/model_rnn.py )

1. 程式片段

在當中你會看到這段,做intra RNN的運算,當初的我對維度還不是那麼熟悉,看到以下code一頭霧水,雖然有註解,但維度操作好多。

    def forward(self, x):
        '''
           x: [B, N, K, S]
           out: [Spks, B, N, K, S]
        '''
        B, N, K, S = x.shape
        # intra RNN
        # [BS, K, N]
        intra_rnn = x.permute(0, 3, 2, 1).contiguous().view(B*S, K, N)
        # [BS, K, H]
        intra_rnn, _ = self.intra_rnn(intra_rnn)
        # [BS, K, N]
        intra_rnn = self.intra_linear(intra_rnn.contiguous().view(B*S*K, -1)).view(B*S, K, -1)
        # [B, S, K, N]
        intra_rnn = intra_rnn.view(B, S, K, N)
        # [B, N, K, S]
        intra_rnn = intra_rnn.permute(0, 3, 2, 1).contiguous()
        intra_rnn = self.intra_norm(intra_rnn)

2. 分段解析

這裡我將上面的code,每一小段對應一個p的function做測試,這裡忽略self.intra_rnn, self.intra_linear的計算

from einops import rearrange
import torch    

x = torch.randn(2, 3, 4, 5)
B, N, K, S = x.shape

def p1(x):
    t1 = x.permute(0, 3, 2, 1).contiguous().view(B * S, K, N)
    t2 = rearrange(x, 'b n k s -> (b s) k n')

    print(f'p1: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1

def p2(x):
    t1 = x.contiguous().view(B * S * K, -1)
    t2 = rearrange(x, 'bs k n -> (bs k) n')

    print(f'p2: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1
    
def p3(x):
    t1 = x.view(B*S, K, -1)
    # 這裡因為要將一個分解為兩個,所以要給bs或k當中一個值
    t2 = rearrange(x, '(bs k) n -> bs k n', k = K) 

    print(f'p3: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1
    
def p4(x):
    t1 = x.view(B, S, K, N).permute(0, 3, 2, 1).contiguous()
    t2 = rearrange(x, '(b s) k n -> b n k s', b = B) 

    print(f'p4: {t1.shape}, {t2.shape}')
    assert torch.equal(t1, t2), "The tensors t1 and t2 are not equal."
    return t1

3. 全部程式

可以將原本的程式修改成下面,看起來應該精簡很多。

def all(self, x):
    B, N, K, S = x.shape
    intra_rnn = rearrange(x, 'b n k s -> (b s) k n')
    intra_rnn, _ = self.intra_rnn(intra_rnn)
    intra_rnn = self.intra_linear(rearrange(intra_rnn, 'bs k n -> (bs k) n'))
    
    # 這行可以取代下面兩行
    intra_rnn = rearrange(intra_rnn, '(b s k) n -> b n k s', b = B, k = K)
    # intra_rnn = rearrange(intra_rnn, '(bs k) n -> bs k n', k = K) 
    # intra_rnn = rearrange(intra_rnn, '(b s) k n -> b n k s', b = B) 

今天實際拿我先前看過的程式來做個舉例,應該可以看的出來精簡很多,接下來只需要多練習,之後寫model可以嘗試看看以這種方式表達。

今天就先到這裡囉~~


上一篇
[Day 17] einops增加model可讀性
下一篇
[Day 19] 實作model理解觀念 - 1
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言