今天把sft.py
也改成了可以resume training,其實看到現在,對於整個訓練的細節了解後,可以歸納出pretrain
與sft
在訓練的程式上的不同點:
PretrainDataset
vs SFTDataset
./data/pretrain_data.bin
vs ./data/sft_data.csv
(X,Y)
vs (X,Y,loss_mask)
raw_model.last_loss
loss_mask
計算loss
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=0,reduce=False)
loss_mask = loss_mask.view(-1)
loss = torch.sum(loss*loss_mask)/loss_mask.sum()
除此之外的程式碼都一模一樣,為了重本這樣重複兩份code改來改去容易產生bug,接下來打算花一些時間修改這樣pretrain
與sft
的訓練都使用同一個程式。