這個章節,我們將談到 test()的部分。
進入主題之前,我們要注意的是,test_loader是固定的1000筆資料直接使用(沒epoch),所以不download參數。因此,它並不帶有batch index 這欄位。
首先,先宣告這是validation的作業,而後將一些變數清0。
model.eval()
test_loss = 0
correct = 0
由於我們在這個階段,目的是測試model的準確性,也就利用model做一種推估,再比較和實際值的差異,進而得到其準確性。所以,我們不需要 backword and optimizer來優化 weights。因此,我們宣告暫時不需要gradient。
with torch.no_grad():
接下來,每次讀一筆資料。讀入的資料,放入device中。進而餵入model得到預測結果output。
data, target = data.to(device), target.to(device)
output = model(data)
接著,我們計算 output 和 target 之間的 loss,累計之。同時,也計算推論正確的次數 correct,亦累計之。
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
#
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
#
correct += pred.eq(target.view_as(pred)).sum().item()
Pred and correct大致說明一下。假設我們預測手寫數字為5,那麼output中機率最高的應為output[5],所以pred=5。若預測的答案和target (5)相同,correct就累加1。
以上章節,講完本機的部分。
接下來會談及雲端的部分。