Day 17 煉丹爐開始煉丹啦 - 訓練神經網路 介紹整個訓練的流程,但是準確率大概在 75% 左右就到極限了,因此今天加入了卷積神經網路。
定義網路時會寫兩個 module,分別卷積和全連階層,整個運作的流程是 Conv (卷積) → Flatten (平坦) → FC (Fully Connect, 全連接層),程式碼如下:
class CustomConvNeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.cnn_module = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.MaxPool2d(2, 2),
nn.ReLU(),
)
self.fc_modeul = nn.Sequential(
nn.Linear(16 * 53 * 53, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 2)
)
def forward(self, x):
x = self.cnn_module(x)
x = self.flatten(x)
x = self.fc_modeul(x)
return x
cnn_model = CustomConvNeuralNetwork().to(device)
print(cnn_model)
訓練和測試的程式碼與 Day17 相同:
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
# 將資料讀取到GPU中
X, y = X.to(device), y.to(device)
# 運算出結果並計算loss
pred = model(X)
loss = loss_fn(pred, y)
# 反向傳播
loss.backward()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
model.eval()
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
# 驗證或測試時記得加入 torch.no_grad() 讓神經網路不要更新
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
最後我們一樣訓練十個 epochs,看看結果如何:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(cnn_model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, cnn_model, loss_fn, optimizer)
test_loop(val_dataloader, cnn_model, loss_fn)
print("Done!")
最後可以看到準確率大約為 80%:
加入了卷積的運算雖然有增加一些準確率,但應該還能做得更好,接下來會介紹卷積的運作和嘗試訓練出更準的模型。