本文會探討如何使用 Weights and Biases (wandb) 來解決模型訓練中的常見困境,並展示一個完整的例子,幫助你輕鬆追蹤實驗、管理模型版本及參數變動。
假設你正在開發一個圖像分類模型,使用卷積神經網絡 (CNN) 進行訓練。你已經花了數天時間不斷進行參數調整,但每次實驗結束後,結果無法令人滿意。你試圖記錄下每次實驗的超參數設置、訓練時間、準確率等資訊,但手動管理變得愈來愈複雜,數據不斷堆積,讓你難以找到具體的性能改進。
這時候,wandb 提供了即時的實驗追蹤、參數記錄、模型可視化,並自動儲存結果,幫助你更快地找到最佳模型。
首先,安裝 wandb:
pip install wandb
登入 wandb :
wandb login
接著,在程式碼中初始化 wandb:
import wandb
# 初始化 wandb,設定專案名稱
wandb.init(project="image-classification", config={
"learning_rate": 0.001,
"epochs": 10,
"batch_size": 32
})
接下來,設置一個簡單的 CNN 模型來訓練 CIFAR-10 資料集,並使用 wandb 追蹤訓練過程中的超參數、損失函數與準確度。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# CIFAR-10 資料集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=wandb.config.batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=wandb.config.batch_size, shuffle=False)
# 定義 CNN 模型
class CNN(nn.Module):
# 請自行定義
pass
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
在每個 epoch 中,記錄損失函數和準確度到 wandb。
# 訓練過程
def train_model():
for epoch in range(wandb.config.epochs):
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(trainloader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 計算損失和準確度
running_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 記錄每個 epoch 的結果到 wandb
wandb.log({
"epoch": epoch + 1,
"loss": running_loss / len(trainloader),
"accuracy": correct / total
})
print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {correct / total}')
train_model()
每次訓練結束後,wandb 會自動將結果上傳至雲端,你可以在 wandb 的儀表板上查看每次實驗的詳細數據,包括損失函數、準確度、參數變動,並與其他實驗結果進行比較。
# 儲存模型
torch.save(model.state_dict(), "cnn_model.pth")
wandb.save("cnn_model.pth")
在多次訓練實驗中,可以隨時調整超參數並重新訓練模型。例如,嘗試不同的學習率,wandb 會自動幫你記錄每次實驗的變更。
# 調整學習率後重新訓練
wandb.config.update({"learning_rate": 0.0001})
train_model()
wandb 就像一位超級助理,幫你記錄每次實驗的點點滴滴,包括你調整的參數、訓練過程中的損失和準確度變化,甚至連模型本身都能幫你保存下來! 你可以隨時查看這些記錄,就像翻閱實驗筆記一樣,清楚地知道每次嘗試的結果。wandb 還能將這些數據轉化成圖表,讓你一目了然地看到模型的訓練過程和性能變化。 你可以輕鬆比較不同參數設置的效果,找到讓模型表現最佳的秘方!有了 wandb,模型訓練不再是一團亂麻,你就能專注於提升模型的準確度,讓它成為辨識貓貓狗狗的超級高手!