iT邦幫忙

2024 iThome 鐵人賽

DAY 26
0
自我挑戰組

30 天程式學習筆記:我的自學成長之路系列 第 26

[DAY 26]告別手動記錄:Weights & Biases 實戰指南

  • 分享至 

  • xImage
  •  

本文會探討如何使用 Weights and Biases (wandb) 來解決模型訓練中的常見困境,並展示一個完整的例子,幫助你輕鬆追蹤實驗、管理模型版本及參數變動。

假設你正在開發一個圖像分類模型,使用卷積神經網絡 (CNN) 進行訓練。你已經花了數天時間不斷進行參數調整,但每次實驗結束後,結果無法令人滿意。你試圖記錄下每次實驗的超參數設置、訓練時間、準確率等資訊,但手動管理變得愈來愈複雜,數據不斷堆積,讓你難以找到具體的性能改進。

這時候,wandb 提供了即時的實驗追蹤、參數記錄、模型可視化,並自動儲存結果,幫助你更快地找到最佳模型。

1. 安裝和設定 wandb

首先,安裝 wandb:

pip install wandb

登入 wandb :

wandb login

接著,在程式碼中初始化 wandb:

import wandb

# 初始化 wandb,設定專案名稱
wandb.init(project="image-classification", config={
    "learning_rate": 0.001,
    "epochs": 10,
    "batch_size": 32
})

2. 訓練一個簡單的 CNN 模型

接下來,設置一個簡單的 CNN 模型來訓練 CIFAR-10 資料集,並使用 wandb 追蹤訓練過程中的超參數、損失函數與準確度。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# CIFAR-10 資料集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=wandb.config.batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=wandb.config.batch_size, shuffle=False)

# 定義 CNN 模型
class CNN(nn.Module):
	# 請自行定義
    pass

model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

3. 訓練模型並記錄到 wandb

在每個 epoch 中,記錄損失函數和準確度到 wandb。

# 訓練過程
def train_model():
    for epoch in range(wandb.config.epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # 計算損失和準確度
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # 記錄每個 epoch 的結果到 wandb
        wandb.log({
            "epoch": epoch + 1,
            "loss": running_loss / len(trainloader),
            "accuracy": correct / total
        })

        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {correct / total}')

train_model()

4. 使用 wandb 追蹤和分析實驗結果

每次訓練結束後,wandb 會自動將結果上傳至雲端,你可以在 wandb 的儀表板上查看每次實驗的詳細數據,包括損失函數、準確度、參數變動,並與其他實驗結果進行比較。

# 儲存模型
torch.save(model.state_dict(), "cnn_model.pth")
wandb.save("cnn_model.pth")

5. 使用 wandb 進行版本控制與參數對比

在多次訓練實驗中,可以隨時調整超參數並重新訓練模型。例如,嘗試不同的學習率,wandb 會自動幫你記錄每次實驗的變更。

# 調整學習率後重新訓練
wandb.config.update({"learning_rate": 0.0001})
train_model()

使用 WandB 讓模型訓練不再像大海撈針!

wandb 就像一位超級助理,幫你記錄每次實驗的點點滴滴,包括你調整的參數、訓練過程中的損失和準確度變化,甚至連模型本身都能幫你保存下來! 你可以隨時查看這些記錄,就像翻閱實驗筆記一樣,清楚地知道每次嘗試的結果。wandb 還能將這些數據轉化成圖表,讓你一目了然地看到模型的訓練過程和性能變化。 你可以輕鬆比較不同參數設置的效果,找到讓模型表現最佳的秘方!有了 wandb,模型訓練不再是一團亂麻,你就能專注於提升模型的準確度,讓它成為辨識貓貓狗狗的超級高手!


上一篇
[DAY 25]模型複雜,數據混亂追蹤難:Weights & Biases 實現高效實驗管理!
下一篇
[DAY 27]告別手動調參:AutoML 打造高效機器學習流程
系列文
30 天程式學習筆記:我的自學成長之路30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言