iT邦幫忙

2021 iThome 鐵人賽

DAY 18
1
AI & Data

新手一起來Azure上玩 NNI (auto-ML的一種)系列 第 18

模型的內容08 test()

這個章節,我們將談到 test()的部分。

進入主題之前,我們要注意的是,test_loader是固定的1000筆資料直接使用(沒epoch),所以不download參數。因此,它並不帶有batch index 這欄位。

首先,先宣告這是validation的作業,而後將一些變數清0。

    model.eval()
    test_loss = 0
    correct = 0

由於我們在這個階段,目的是測試model的準確性,也就利用model做一種推估,再比較和實際值的差異,進而得到其準確性。所以,我們不需要 backword and optimizer來優化 weights。因此,我們宣告暫時不需要gradient。

with torch.no_grad():

接下來,每次讀一筆資料。讀入的資料,放入device中。進而餵入model得到預測結果output。

            data, target = data.to(device), target.to(device)
            output = model(data)

接著,我們計算 output 和 target 之間的 loss,累計之。同時,也計算推論正確的次數 correct,亦累計之。

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
	 #
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
	 #
            correct += pred.eq(target.view_as(pred)).sum().item()

Pred and correct大致說明一下。假設我們預測手寫數字為5,那麼output中機率最高的應為output[5],所以pred=5。若預測的答案和target (5)相同,correct就累加1。

以上章節,講完本機的部分。
接下來會談及雲端的部分。


上一篇
模型的內容07 train()
下一篇
NNI如何搬到雲端上玩?
系列文
新手一起來Azure上玩 NNI (auto-ML的一種)30

尚未有邦友留言

立即登入留言