由於模型訓練進展很慢,沒有什麼能更新的,今天唯一做的事情就是修改了一下訓練時的程式,每經過一定的iteration就儲存一次最後的checkpoint
,否則模型訓練一個epoch實在太久,必須要能夠resume上次的訓練。
def save_ckpt(model, model_args, iter_num, best_val_loss, epoch, ckpt_path):
"""
將模型、迭代次數、最佳驗證損失和訓練的epoch保存到checkpoint文件中。
Parameters:
model (torch.nn.Module): 要保存的模型
iter_num (int): 目前的迭代次數
best_val_loss (float): 最佳驗證損失
epoch (int): 目前的訓練epoch數
ckpt_path (str): checkpoint文件的路徑
"""
checkpoint = {
'model': model.state_dict(),
'model_args': model_args,
'iter_num': iter_num,
'best_val_loss': best_val_loss,
'epoch': epoch
}
torch.save(checkpoint, ckpt_path)
print(f'Checkpoint saved to {ckpt_path}')