1. Building the decoder
The decoder itself was finished yesterday; I've added comments throughout the code below, which should hopefully help in understanding it.
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, feedforward_dim, n_dec_layers, n_attn_heads, dropout,
dec_voca_length, max_pos_length , device , skip_encoder_attn = False):
"""
hidden_dim: embedding size
feedforward_dim: feed-forward layer size
n_dec_layers: number of decoder layers
n_attn_heads: number of attention heads
dropout: dropout rate
dec_voca_length: output vocabulary size (English)
max_pos_length: maximum target sequence length (dec_max_len)
"""
super().__init__()
self.device = device
# token embedding
self.dec_tok_embedding = nn.Embedding(dec_voca_length, hidden_dim)
# position embedding
self.dec_pos_embedding = nn.Embedding(max_pos_length, hidden_dim)
# stack n_dec_layers TransformerDecoderLayer layers
self.transformer_decoder_layers = nn.ModuleList([TransformerDecoderLayer(
hidden_dim,
feedforward_dim,
n_dec_layers,
n_attn_heads,
dropout,
device, skip_encoder_attn) for _ in range(n_dec_layers)])
# output layer: projects to dec_voca_length logits
self.full_conn_out = nn.Linear(hidden_dim, dec_voca_length)
self.dropout = nn.Dropout(dropout)
self.scale = torch.sqrt(torch.FloatTensor([hidden_dim])).to(device)
def forward(self, dec_seq, enc_hidden , dec_mask, enc_mask):
"""
dec_seq: [batch_size, trg_len]
enc_hidden: [batch_size, src_len, hid_dim]
dec_mask: [batch_size, trg_len]
enc_mask: [batch_size, src_len]
"""
batch_size = dec_seq.shape[0]
dec_len = dec_seq.shape[1]
# pos [batch_size, trg_len]
pos = torch.arange(0, dec_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
# scale the decoder token embedding by sqrt(hidden_dim) and add the decoder position embedding
# (self.scale is defined above for exactly this purpose)
# dec_seq [batch_size, trg_len, hid_dim]
dec_seq = self.dropout((self.dec_tok_embedding(dec_seq) * self.scale) + self.dec_pos_embedding(pos))
for layer in self.transformer_decoder_layers:
dec_seq, encoder_decoder_attention , decoder_self_attention = layer(dec_seq, enc_hidden, dec_mask, enc_mask)
# dec_seq output tensor shape [batch_size, trg_len, hid_dim]
# encoder_decoder_attention output tensor shape [batch_size, n_heads, trg_len, src_len]
output = self.full_conn_out(dec_seq)
# output tensor shape [batch_size, trg_len, output_dim]
return output, encoder_decoder_attention , decoder_self_attention
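A quick aside on the pos tensor built in forward above: arange creates the position indices 0..trg_len-1, and unsqueeze/repeat copies that row once per batch element. A minimal standalone sketch (the sizes here are made up for illustration):

import torch
batch_size, dec_len = 2, 5
pos = torch.arange(0, dec_len).unsqueeze(0).repeat(batch_size, 1)
print(pos)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4]])

Each row can then index dec_pos_embedding, giving every token a learned position vector.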
class TransformerDecoderLayer(nn.Module):
def __init__(self, hidden_dim, feedforward_dim, n_dec_layers, n_attn_heads, dropout, device, skip_encoder_attn=False):
"""
hidden_dim: embedding size
feedforward_dim: feed-forward layer size
n_dec_layers: number of decoder layers (passed in, but not used inside a single layer)
n_attn_heads: number of attention heads
dropout: dropout rate
"""
super().__init__()
self.skip_encoder_attn = skip_encoder_attn
self.self_attention_sublayer = MultiHeadAttentionSubLayer(hidden_dim, n_attn_heads, dropout, device)
self.self_attn_layernorm = nn.LayerNorm(hidden_dim)
if not skip_encoder_attn:
self.encoder_attention_sublayer = MultiHeadAttentionSubLayer(hidden_dim, n_attn_heads, dropout, device)
self.encoder_attn_layernorm = nn.LayerNorm(hidden_dim)
self.positionwise_feedforward = PosFeedForwardSubLayer(hidden_dim,feedforward_dim ,dropout)
self.feedforward_layernorm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, dec_seq, enc_hidden , dec_mask, enc_mask):
"""
dec_seq: [batch_size, trg_len, hid_dim]
enc_hidden: [batch_size, src_len, hid_dim]
dec_mask: [batch_size, trg_len]
enc_mask: [batch_size, src_len]
"""
# self-attention sub-layer
_dec_seq, decoder_self_attention = self.self_attention_sublayer(dec_seq, dec_seq, dec_seq, dec_mask)
# dropout, residual connection and layer norm
# dec_seq output tensor shape [batch_size, trg_len, hid_dim] => becomes the decoder's query (q)
dec_seq = self.self_attn_layernorm(dec_seq + self.dropout(_dec_seq))
# should the encoder-attention sub-layer be applied?
if not self.skip_encoder_attn:
#encoder attention
_dec_seq, encoder_decoder_attention = self.encoder_attention_sublayer(dec_seq, enc_hidden, enc_hidden, enc_mask)
# dropout, residual connection and layer norm
# dec_seq [batch_size, trg_len, hid_dim]
dec_seq = self.encoder_attn_layernorm(dec_seq + self.dropout(_dec_seq))
else:
encoder_decoder_attention = None
# positionwise feedforward
_dec_seq = self.positionwise_feedforward(dec_seq)
#dropout, residual and layer norm
dec_seq = self.feedforward_layernorm(dec_seq + self.dropout(_dec_seq))
# dec_seq [batch_size, trg_len, hid_dim]
# encoder_decoder_attention [batch_size, n_heads, trg_len, src_len]
# decoder_self_attention [batch_size, n_heads, trg_len, trg_len]
return dec_seq, encoder_decoder_attention , decoder_self_attention
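The same residual pattern, LayerNorm(x + Dropout(sublayer(x))), appears three times in this layer. In isolation it looks like this (a minimal sketch with nn.Linear standing in for the real attention/feed-forward sub-layers from yesterday's encoder post):

import torch
import torch.nn as nn
x = torch.randn(2, 7, 256)          # [batch_size, trg_len, hid_dim]
sublayer = nn.Linear(256, 256)      # stand-in for a real sub-layer
norm = nn.LayerNorm(256)
drop = nn.Dropout(0.1)
out = norm(x + drop(sublayer(x)))   # dropout -> residual add -> layer norm
print(out.shape)                    # torch.Size([2, 7, 256])

This "post-norm" ordering matches the original Transformer paper; the residual connection keeps gradients flowing through deep stacks.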
class Transformer(nn.Module):
def __init__(self,encoder, decoder, src_pad_idx, trg_pad_idx, device):
"""
encoder: Transformer encoder
decoder: Transformer decoder
src_pad_idx: encoder padding index
trg_pad_idx: decoder padding index
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_pad_idx = src_pad_idx
self.trg_pad_idx = trg_pad_idx
self.device = device
def make_src_mask(self, src):
"""
src: [batch size, src len]
"""
# build the padding mask so attention largely ignores positions holding the padding index
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
# src_mask [batch size, 1, 1, src len]
return src_mask
def make_trg_mask(self, trg):
"""
trg: [batch size, trg len]
"""
# Padding mask
# trg_pad_mask [batch size, 1, 1, trg len]
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
trg_len = trg.shape[1]
# Look-ahead mask (a toy example of the combined mask follows after this class)
#trg_sub_mask = [trg len, trg len]
trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
#trg_mask = [batch size, 1, trg len, trg len]
trg_mask = trg_pad_mask & trg_sub_mask
return trg_mask
def forward(self, src, trg):
"""
src: [batch size, src len] - the Chinese sentence to translate
trg: [batch size, trg len] - the English target sentence
"""
src_mask = self.make_src_mask(src)
trg_mask = self.make_trg_mask(trg)
#src_mask = [batch size, 1, 1, src_len]
#trg_mask = [batch size, 1, trg_len, trg_len]
# enc_src = [batch_size, src_len, hid_dim]
enc_src , encoder_self_attention = self.encoder(src, src_mask)
output, encoder_decoder_attention , decoder_self_attention = self.decoder(trg, enc_src, trg_mask, src_mask)
#output = [batch_size, trg_len, output_dim]
#encoder_decoder_attention = [batch_size, n_heads, trg_len, src_len]
return output, encoder_decoder_attention , encoder_self_attention , decoder_self_attention
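To make the mask returned by make_trg_mask concrete, here is a toy example (a minimal sketch; the pad index and the five-token sentence are made up for illustration):

import torch
pad_idx = 1                                                # hypothetical pad index
trg = torch.tensor([[2, 5, 7, 1, 1]])                      # last two tokens are padding
trg_pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, 5]
trg_sub_mask = torch.tril(torch.ones((5, 5))).bool()       # [5, 5]
trg_mask = trg_pad_mask & trg_sub_mask                     # broadcasts to [1, 1, 5, 5]
print(trg_mask[0, 0].int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0]], dtype=torch.int32)

Row t shows what position t may attend to: only itself and earlier positions (look-ahead mask), and never the padding columns.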
2. Building and training the model
INPUT_DIM = len(source_ch.vocab)
OUTPUT_DIM = len(target_en.vocab)
MAX_SENT_LENGTH = 40
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_FF_DIM = 512
DEC_FF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1
LEARNING_RATE = 0.0005
SRC_PAD_IDX = source_ch.vocab.stoi[source_ch.pad_token]
TRG_PAD_IDX = target_en.vocab.stoi[target_en.pad_token]
enc = TransformerEncoder(HID_DIM, ENC_FF_DIM, ENC_LAYERS, ENC_HEADS, ENC_DROPOUT,INPUT_DIM, MAX_SENT_LENGTH,device)
dec = TransformerDecoder(HID_DIM, DEC_FF_DIM,
DEC_LAYERS,
DEC_HEADS,
DEC_DROPOUT,
OUTPUT_DIM, MAX_SENT_LENGTH,
device)
model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
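Before training, it is handy to check how many trainable parameters the model has (an optional sanity check, not part of the original code):

def count_parameters(model):
    # sum the element counts of all tensors that require gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('The model has {:,} trainable parameters'.format(count_parameters(model)))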
def train(model, iterator, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for i, batch in enumerate(iterator):
src = batch.src
trg = batch.trg
optimizer.zero_grad()
output, _ , _ , _ = model(src, trg[:,:-1])
#output = [batch size, trg len - 1, output dim]
#trg = [batch size, trg len]
# teacher forcing: the decoder input trg[:,:-1] drops <eos>, and the target
# trg[:,1:] below drops <sos>, so the prediction at step t is scored against token t+1
output_dim = output.shape[-1]
output = output.contiguous().view(-1, output_dim)
trg = trg[:,1:].contiguous().view(-1)
#output = [batch size * (trg len - 1), output dim]
#trg = [batch size * (trg len - 1)]
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
model.eval()
epoch_loss = 0
with torch.no_grad():
for i, batch in enumerate(iterator):
src = batch.src
trg = batch.trg
output, _ , _ , _= model(src, trg[:,:-1])
#output = [batch size, trg len - 1, output dim]
#trg = [batch size, trg len]
output_dim = output.shape[-1]
output = output.contiguous().view(-1, output_dim)
trg = trg[:,1:].contiguous().view(-1)
#output = [batch size * (trg len - 1), output dim]
#trg = [batch size * (trg len - 1)]
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(iterator)
N_EPOCHS = 30
CLIP = 1
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
start_time = time.time()
train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
valid_loss = evaluate(model, valid_iterator, criterion)
end_time = time.time()
torch.save(model.state_dict(), model_dir + 'model-{}.pt'.format(epoch))
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), model_dir + 'best-model.pt')
print ("Epoch {} training time: {:.2f} sec Training Loss: {:.3f} , Valiation Loss: {:.3f}".format( epoch , end_time - start_time , train_loss , valid_loss))
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
model.eval()
tokens = [token.lower() for token in sentence]
tokens = [src_field.init_token] + tokens + [src_field.eos_token]
src_indexes = [src_field.vocab.stoi[token] for token in tokens]
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
src_mask = model.make_src_mask(src_tensor)
with torch.no_grad():
enc_src , encoder_self_attention = model.encoder(src_tensor, src_mask)
# the translated sentence starts with the init token (<sos>)
trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
for i in range(max_len):
trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
trg_mask = model.make_trg_mask(trg_tensor)
with torch.no_grad():
output, encoder_decoder_attention , decoder_self_attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
# take the model's most likely next token
pred_token = output.argmax(2)[:,-1].item()
# append it to the translation so far
trg_indexes.append(pred_token)
# stop once <eos> is produced
if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
break
trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
return trg_tokens[1:], encoder_decoder_attention , encoder_self_attention , decoder_self_attention
# pick an arbitrary index from the training set
sample_index = 1250
src = vars(train_dataset.examples[sample_index])['src']
trg = vars(train_dataset.examples[sample_index])['trg']
print('src = ' , src)
print('trg = ' , trg)
translation, encoder_decoder_attention , encoder_self_attention , decoder_self_attention = translate_sentence(src, source_ch, target_en, model, device)
print('translation = ', translation)
# output
# src = ['你', '为', '什', '么', '会', '认', '为', '我', '在', '想', '你']
# trg = ['why', 'do', 'you', 'think', 'i', "'m", 'thinking', 'about', 'you', '?']
# translation = ['why', 'do', 'you', 'think', 'i', "'m", 'thinking', 'about', 'you', '?', '<eos>']
That's all of the Transformer code. Tomorrow I'll start covering BERT.