前情提要: 昨天講解了DEMUCS encoder的作法,盡量寫成每一個block方便修改及增加可讀性,並且在unit test時比較方便。
activation參考:
https://www.geeksforgeeks.org/choosing-the-right-activation-function-for-your-neural-network/
lstm參考:
https://blog.csdn.net/baidu_38963740/article/details/117197619
目前的Decoder還沒有加入skip connection,只是照昨天的樣式先做。
這裡可以照原本的方式用insert,也可以用append最後在[::-1]將list做翻轉的操作,看個人喜好。
另外要小心的是最後一層不加ReLU,主要是最後一層就是輸出了,也就是輸出audio,如果過了ReLU,那原本負號的值都變成0了,一定不符合我們想要的。
class DoublConvTr(nn.Module):
def __init__(self, hidden, chout, K, S, last = False):
super().__init__()
layers = [
nn.Conv1d(hidden, hidden * 2, 1, 1),
nn.GLU(1),
nn.ConvTranspose1d(hidden, chout, K, S)
]
if not last: # 最後一層不加ReLU
layers.append(nn.ReLU())
self.DoublConvTr = nn.Sequential(*layers)
def forward(self, x):
return self.DoublConvTr(x)
class Decoder(nn.Module):
def __init__(
self,
chout = 1,
hidden = 48,
kernel_size = 8,
stride = 4,
growth = 2,
depth = 5,
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.depth = depth
self.decoder = nn.ModuleList()
for idx in range(depth):
# last = True if idx == 0 else False
# self.decoder.insert(0, DoublConvTr(hidden, chout, kernel_size, stride, last))
last = True if idx == depth - 1 else False
self.decoder.append(DoublConvTr(hidden, chout, kernel_size, stride, last))
chout = hidden
hidden *= growth
self.decoder = self.decoder[::-1]
def forward(self, x):
for idx, dec in enumerate(self.decoder):
x = dec(x)
print(f'idx: {idx}, x: {x.size()}')
return x
這邊直接使用雙向的lstm,再透過linear讓維度變回原本的dim。
class Bottleneck(nn.Module):
def __init__(
self,
dim,
num_layers = 2,
bi = True,
):
super().__init__()
self.lstm = nn.LSTM(dim, dim, num_layers, bidirectional = bi, batch_first = True)
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x, hidden=None):
x, hidden = self.lstm(x, hidden)
x = self.linear(x)
return x
最後將我們目前有的都添加到Demucs裡面,分別是encoder bottleneck decoder ,並在forward呼叫每一個做運算,此時還沒有加上skip connection,但我們可以先測試整個model是不是可以跑起來,以及他的維度有沒有符合預期。
class Demucs(nn.Module):
def __init__(
self,
chin=1,
chout=1,
hidden=48,
depth=5,
kernel_size=8,
stride=4,
resample=4,
growth=2,
normalize=True,
sample_rate=16_000
):
super().__init__()
self.encoder = Encoder(chin, hidden, kernel_size, stride, growth, depth)
self.bottleneck = Bottleneck(self.encoder.final_hidden)
self.decoder = Decoder(chout, hidden, kernel_size, stride, growth, depth)
def forward(self, x):
x = self.encoder(x)
x = rearrange(x, 'b c l -> b l c')
x = self.bottleneck(x)
x = rearrange(x, 'b l c -> b c l')
x = self.decoder(x)
return x
if __name__ == "__main__":
x = torch.rand(2, 1, 16000)
model = Demucs()
print(model)
x = model(x)
print(x.size())
今天就先到這裡囉~~
可以發現最後的Demucs其實很簡單,在很多git你也會發現,最後的主model其實都很簡單,因為都呼叫寫好的class,最後只是拼裝再一起而已。
我自己剛開始學的時候會覺得何必這麼麻煩,就全部寫一起就好啦,但當後續要修改model的時候發現實在不方便,也不好測試你修改的對不對,所以就花了很多時間去適應這種寫法。