iT邦幫忙

2023 iThome 鐵人賽

DAY 18
0
AI & Data

30天把AI知識傳授給女友系列 第 18

Day 18 存取 Pytroch 訓練好的模型和權重

  • 分享至 

  • xImage
  •  

今天先介紹怎麼把模型存下來,神經網路在訓練的時候所有權重都是存在記憶體中的,只要程式或電腦關掉這些資料就會消失,因此我們必須將這些訓練好的資訊存到硬碟中。

PyTorch 模型將學習到的參數儲存在內部狀態字典中,稱為state_dict。這些可以透過以下方法保存torch.save

torch.save(model.state_dict(), 'cat_vs_rabbit_cls_v1.pth')

要載入模型權重,您需要先建立相同架構的物件,然後使用load_state_dict()方法載入參數。利用to.(device)將模型放到同個運算機器上(如果dataloader和model不同機器會報錯哦)

reload_model = NeuralNetwork(3*224*224)
reload_model.load_state_dict(torch.load('cat_vs_rabbit_cls_v1.pth'))
reload_model.to(device)
reload_model.eval()

https://ithelp.ithome.com.tw/upload/images/20230923/20153503uV71kvToBV.png

跑測試看看模型有沒有載入成功,我們可以利用測試的函數,不過只需要跑一次就好:

def test_loop(dataloader, model):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epochs = 1
for t in range(epochs):
    test_loop(val_dataloader, reload_model)

可以看到準確率有 72% 左右 (因為我有重新訓練模型,所以和昨天準確率不同是正常的)
https://ithelp.ithome.com.tw/upload/images/20230924/20153503LgzPI6Ug6T.png

結語

今天先介紹怎麼存取模型,明天會開始試著優化我們的架構。


上一篇
Day 17 煉丹爐開始煉丹啦 - 訓練神經網路
下一篇
Day 19 卷一卷或許準一點 - 卷積神經網路
系列文
30天把AI知識傳授給女友30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言