sft訓練的程式碼
sft.py
以下為簡化過的程式碼,X
, Y
, loss_mask
的內容如何產生已經在前面介紹SFT dataset的時候介紹過了,所以SFT訓練部份主要的重點剩下Loss的計算:
ignore_index=0
: 忽略掉padding_token_id
對應的Lossreduce=False
: 這邊只是暫時不做reduce,然而後面乘完loss_mask
以後還是會做reduce,這樣每一個訓練樣本對於模型的影響才會是一致的loss = torch.sum(loss*loss_mask)/loss_mask.sum()
torch.sum(loss*loss_mask)
: 僅計算answer
對應tokens的losstorch.sum(...)/loss_mask.sum()
: 最後對loss做mean reduce
def train_epoch(epoch):
start_time=time.time()
for step, (X, Y,loss_mask) in enumerate(train_loader):
......
logits = model(X, Y)
loss = F.cross_entropy(logits, Y, ignore_index=0,reduce=False)
loss = torch.sum(loss*loss_mask)/loss_mask.sum()
......
if __name__=="__main__":
max_epoch = 10
batch_size = 32
# model
max_seq_len = 512
dim = 512
n_layers = 8
n_heads = 8
multiple_of = 32
......