SFT Training Code
sft.py
Below is the simplified code. How X, Y, and loss_mask are produced was already covered in the earlier section on the SFT dataset, so the main remaining point of the SFT training step is the loss computation:
- `ignore_index=0`: ignore the loss at positions whose target is `padding_token_id` (here, 0)
- `reduce=False` (the legacy spelling of `reduction='none'`): skip the reduction for now; the reduction is still applied after multiplying by `loss_mask`, so every training sample influences the model equally
- `loss = torch.sum(loss*loss_mask)/loss_mask.sum()`:
  - `torch.sum(loss*loss_mask)`: sum the loss over answer tokens only
  - `/loss_mask.sum()`: then mean-reduce over the number of answer tokens
def train_epoch(epoch):
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        ......
        logits = model(X, Y)
        # Per-token loss: flatten to (B*T, V) vs (B*T,), ignore padding (id 0),
        # and defer the reduction so the mask can be applied first.
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1),
                               ignore_index=0, reduction="none").view(Y.shape)
        # Keep only answer-token losses, then average over their count.
        loss = torch.sum(loss * loss_mask) / loss_mask.sum()
        ......
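To make the masked-loss computation concrete, here is a minimal, self-contained sketch of the same pattern on toy tensors; the shapes, the choice of 0 as the padding id, and the mask positions are assumptions for illustration, not the original data pipeline:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 6, 10

# Fake logits and shifted targets; 0 is assumed to be the padding token id.
logits = torch.randn(batch, seq_len, vocab_size)
Y = torch.randint(1, vocab_size, (batch, seq_len))
Y[0, 4:] = 0  # pad the tail of the first sample

# loss_mask is 1 only on answer tokens (here: positions 2..3 of each sample).
loss_mask = torch.zeros(batch, seq_len)
loss_mask[:, 2:4] = 1.0

# Per-token loss with no reduction; padded positions contribute zero loss.
loss = F.cross_entropy(
    logits.view(-1, vocab_size), Y.view(-1),
    ignore_index=0, reduction="none",
).view(batch, seq_len)

# Keep only answer-token losses, then mean over the number of answer tokens.
loss = torch.sum(loss * loss_mask) / loss_mask.sum()
```

Note that dividing by `loss_mask.sum()` rather than by `batch * seq_len` is what makes samples with short answers count as much as samples with long prompts.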
if __name__ == "__main__":
    max_epoch = 10
    batch_size = 32
    # model
    max_seq_len = 512
    dim = 512
    n_layers = 8
    n_heads = 8
    multiple_of = 32
    ......
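The `......` placeholders elide optimizer setup, the backward pass, and logging. The sketch below shows where the masked loss fits in one full optimization step; the `ToyLM` model and every hyperparameter here are stand-ins for illustration, not the original Transformer or config:

```python
import torch
import torch.nn.functional as F

class ToyLM(torch.nn.Module):
    """Stand-in for the Transformer; returns logits of shape (B, T, V)."""
    def __init__(self, vocab_size=32, dim=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, x):
        return self.head(self.emb(x))

torch.manual_seed(0)
vocab_size = 32
model = ToyLM(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One fake batch; 0 is assumed to be the padding token id.
X = torch.randint(1, vocab_size, (4, 12))
Y = torch.randint(1, vocab_size, (4, 12))
loss_mask = torch.zeros(4, 12)
loss_mask[:, 6:] = 1.0  # pretend the second half of each row is the answer

logits = model(X)
loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1),
                       ignore_index=0, reduction="none").view(Y.shape)
loss = torch.sum(loss * loss_mask) / loss_mask.sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the two loss lines differ from a standard language-model training step; everything else is ordinary PyTorch boilerplate.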