iT邦幫忙

2021 iThome 鐵人賽

DAY 14
0

上一章節研究完class net(…),這一章節我們繼續研究 def main(args)這部分。

def main(args):
    # define data directory and device (CPU or GPU)
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()
    torch.manual_seed(args['seed'])
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    data_dir = args['data_dir']

    # --- data loader
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)
    # --- end  data loader

    hidden_size = args['hidden_size']
    model = Net(hidden_size=hidden_size).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)
        # report intermediate result
        nni.report_intermediate_result(test_acc)
        logger.debug('test accuracy %g', test_acc)
        logger.debug('Pipe send intermediate result done.')
    # report final result
    nni.report_final_result(test_acc)
    logger.debug('Final result is %g', test_acc)
    logger.debug('Send final result done.')

一開始的部分,主要是定義 data directory and device for CPU or GPU。

接下來我們看一下 data loader的部分。 data loader可以幫我們整理轉換資料外(ToTensor()),還可以依照我們的需要,一批批的吐出來。例如train_loader裡的 batch_size。說明一下,這裡的test_loader,是訓練模型時,用來做validation用的。個人比較喜歡用valid_loader一詞。

資料有時候很分散,此時會影響訓練及驗證速度(計算速度、判別速度),所以需要正規化。 transforms.Normalize((0.1307,), (0.3081,)),由於資料是黑白的,RGB channel 只有1個,所以只有1個數字。Mean=(0.1307,),Std= (0.3081,)。至於裡面的數字為何為此,我也不知。也許可以從data.describe()的統計表中得知一二吧!

另外,shuffle=True,主要是讓每批資料,都能很平均的取樣,以免資料取樣產生偏頗,讓模型導致無效!

初始化model為神經網路 class Net時,要設定其隱藏層的大小,hidden_size。即

model = Net(hidden_size=hidden_size).to(device)。

請注意,model 是放在device中執行,所以 training and validation 時,loader過來的資料也得放於device中,否則 training and validation 時,會找不到資料。可自行實驗。

下一回,我們繼續往下說明 。


上一篇
模型的內容03 Class Net
下一篇
模型的內容05 def main()
系列文
新手一起來Azure上玩 NNI (auto-ML的一種)30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言