iT邦幫忙

2021 iThome 鐵人賽

DAY 17
0

這章節,我們將說明 train()的細部。
程式部分如下:

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

首先,傳進來的函數參數有:args, model, device, train_loader, optimizer, epoch。這些參數,在前面章節已經敘述過了,在此省略。

首先,我們要宣告model要做的事為訓練。

 model.train()

接下來,for loop 中,train_loader會一筆一筆的將資料匯進來,每筆資料的格式為 [batch_idx, (data, target)]。每批資料,都會有自己的索引,batch_index。If 的部分,判斷資料是否還在,不存在則跳離。
接著,我們將一筆train data、target(label)放入device中。(因為,model也在device中。大家要再一起,才能執行。)

 data, target = data.to(device), target.to(device)

然後,記得將optimizer歸零,否則會一直累加!

 optimizer.zero_grad()

再來將data餵入model中,得到 output。

  output = model(data)

接著計算output and target的差距,因為資料為 HxWxC是2D資料,適合用 nll_loss來計算 loss。

  loss = F.nll_loss(output, target)

而後回去計算神經網路叢裡的weights and bias。

 loss.backward()

接下來,根據learn rate and momentum,進行優化計算。(就是weights and bias的調整,讓loss得到最小值)
優化完畢後,將新的weights and bias取代原有的。

 optimizer.step()

再來的if,主要是在用於判斷執行幾圈,才進行logger的動作。得到的結果,大致類似如此:

[2021-09-18 15:34:21] INFO (mnist_AutoML/MainThread) Train Epoch: 10 [0/60000 (0%)]     Loss: 0.003544

下一個章節,我們將談到 test()


上一篇
模型的內容06 def main()
下一篇
模型的內容08 test()
系列文
新手一起來Azure上玩 NNI (auto-ML的一種)30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言