iT邦幫忙

2025 iThome 鐵人賽

DAY 12
0
Cloud Native

新興k8s工作流flyte與MLOps。系列 第 12

Day 12: pytorch workflow範例

  • 分享至 

  • xImage
  •  
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

# Model, Loss, Optimizer
model = resnet18(num_classes=10)
model.conv1 = torch.nn.Conv2d(
    1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
model.to("cpu")
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
train_data = FashionMNIST(root='./data', train=True, download=False, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
for epoch in range(1):
    for images, labels in train_loader:
        images, labels = images.to("cpu"), labels.to("cpu")
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    metrics = {"loss": loss.item(), "epoch": epoch}
    checkpoint_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(checkpoint_dir, "model.pt")
    torch.save(model.state_dict(), checkpoint_path)
    print(metrics)

猜分成

import flytekit as fl

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

image_spec = fl.ImageSpec(
    name="pytorch-resnet", requirements="uv.lock", registry="localhost:30000"
)


@fl.task(container_image=image_spec)
def download_data(data_path: str) -> fl.FlyteDirectory:
    FashionMNIST(root=data_path, train=True, download=True, transform=None)
    return fl.FlyteDirectory(path=data_path)

@fl.task(container_image=image_spec)
def train(epoch: int, lr: float, fd: fl.FlyteDirectory) -> fl.FlyteFile:
    """
    1. Data to dataloader
    2. train model and save the weight
    """
    fd.download()
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    train_data = FashionMNIST(
        root=fd.path, train=True, download=False, transform=transform
    )
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.to("cpu")
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)
    wieght_path = str(fl.current_context().working_directory) + "/model.pt"
    for current_epoch in range(epoch):
        for images, labels in train_loader:
            images, labels = images.to("cpu"), labels.to("cpu")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        metrics = {"loss": loss.item(), "epoch": current_epoch}
        print(metrics)
    torch.save(model.state_dict(), checkpoint_path)
    return fl.FlyteFile(path=wieght_path)


@fl.workflow
def train_resnet_wf(epoch: int = 5, lr: float = 0.001, data_path: str = "/tmp/data") -> fl.FlyteFile:
    data = download_data(data_path)
    weight_file = train(epoch, lr, data)
    return weight_file

上一篇
Day 10: 工作流獲取k8s secret
下一篇
Day 13: flyte-binary部屬 part 1
系列文
新興k8s工作流flyte與MLOps。13
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言