上一章節研究完class net(…),這一章節我們繼續研究 def main(args)這部分。
def main(args):
# define data directory and device (CPU or GPU)
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
torch.manual_seed(args['seed'])
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = args['data_dir']
# --- data loader
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=1000, shuffle=True, **kwargs)
# --- end data loader
hidden_size = args['hidden_size']
model = Net(hidden_size=hidden_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'],
momentum=args['momentum'])
for epoch in range(1, args['epochs'] + 1):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
一開始的部分,主要是定義 data directory and device for CPU or GPU。
接下來我們看一下 data loader的部分。 data loader可以幫我們整理轉換資料外(ToTensor()),還可以依照我們的需要,一批批的吐出來。例如train_loader裡的 batch_size。說明一下,這裡的test_loader,是訓練模型時,用來做validation用的。個人比較喜歡用valid_loader一詞。
資料有時候很分散,此時會影響訓練及驗證速度(計算速度、判別速度),所以需要正規化。 transforms.Normalize((0.1307,), (0.3081,)),由於資料是黑白的,RGB channel 只有1個,所以只有1個數字。Mean=(0.1307,),Std= (0.3081,)。至於裡面的數字為何為此,我也不知。也許可以從data.describe()的統計表中得知一二吧!
另外,shuffle=True,主要是讓每批資料,都能很平均的取樣,以免資料取樣產生偏頗,讓模型導致無效!
初始化model為神經網路 class Net時,要設定其隱藏層的大小,hidden_size。即
model = Net(hidden_size=hidden_size).to(device)。
請注意,model 是放在device中執行,所以 training and validation 時,loader過來的資料也得放於device中,否則 training and validation 時,會找不到資料。可自行實驗。
下一回,我們繼續往下說明 。