Continuing the exploration of sign language data from the previous post, this time we'll build a CNN-based static hand-gesture classification model. We'll go step by step from data preprocessing, model architecture design, and training strategy (including early stopping) to visualizing the training results, building a model that can recognize 24 English letters and laying a solid foundation for extending it into a real-time gesture recognition system.
We use the publicly available Sign Language MNIST dataset from Kaggle. Its key characteristics: each sample is a 28×28 grayscale image stored as a flattened CSV row (one label column plus 784 pixel columns), and it covers 24 letters of the alphabet, since J and Z require motion and are excluded. First, the imports:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
# Load the training CSV: the first column is the label, the remaining 784 are pixel values
df = pd.read_csv("sign_mnist_train/sign_mnist_train.csv")
labels = df['label'].values
images = df.drop('label', axis=1).values.reshape(-1, 28, 28)
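Before filtering, a quick sanity check confirms the reshape worked and shows which label values are present (the shapes in the comments assume the standard Kaggle training CSV):
print(images.shape, labels.shape)   # e.g. (27455, 28, 28) and (27455,)
print(np.unique(labels))            # 9 (J) and 25 (Z) should already be absent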
Filter out J (label = 9). Strictly speaking, the official CSV already contains no J (or Z) samples, since both letters require motion, so this step is mainly a safeguard:
mask = labels != 9
images = images[mask]
labels = labels[mask]
Re-encode the labels as consecutive integers (0–23):
le = LabelEncoder()
labels = le.fit_transform(labels)
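Since LabelEncoder compresses the remaining label values into 0–23, it's worth keeping a mapping from encoded indices back to letters for later inference. A small sketch; `classes_` is standard scikit-learn, and the letter lookup assumes the dataset's 0=A, 1=B, … labeling:
import string

print(le.classes_)   # the original label values the encoder kept, in sorted order

# Illustrative helper: encoded class index -> letter (assumes 0=A, 1=B, ... in the raw labels)
idx_to_letter = {i: string.ascii_uppercase[orig] for i, orig in enumerate(le.classes_)}
print(idx_to_letter[0])   # 'A'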
Split into training and validation sets:
train_images, val_images, train_labels, val_labels = train_test_split(
images, labels, test_size=0.2, stratify=labels, random_state=42
)
Define the Dataset:
class SignLanguageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Keep uint8 so ToPILImage interprets the array as an 8-bit grayscale image
        image = self.images[idx].astype(np.uint8)
        if self.transform:
            image = self.transform(image)
        else:
            # No transform: scale to [0, 1] and add the channel dimension manually
            image = torch.tensor(image / 255.0).unsqueeze(0).float()
        label = self.labels[idx]
        return image, label
train_transform = transforms.Compose([
    ToPILImage(),                        # uint8 numpy array -> PIL image
    transforms.RandomRotation(10),       # rotate up to ±10 degrees
    transforms.RandomHorizontalFlip(),   # mirrors the hand (left- vs. right-handed signing)
    transforms.ToTensor()                # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
])

val_transform = transforms.Compose([
    ToPILImage(),
    transforms.ToTensor()                # no augmentation for validation
])
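To see what the augmentation actually does, you can push one raw image through `train_transform` and display it (a quick visual check; we import matplotlib early here, it is also used for the loss curve later):
import matplotlib.pyplot as plt

aug = train_transform(train_images[0].astype(np.uint8))   # tensor of shape (1, 28, 28)
plt.imshow(aug.squeeze(0), cmap="gray")
plt.title("Augmented sample")
plt.show()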
train_dataset = SignLanguageDataset(train_images, train_labels, transform=train_transform)
val_dataset = SignLanguageDataset(val_images, val_labels, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
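Pulling a single batch is a cheap way to verify the whole input pipeline end to end (a sanity-check sketch, not part of training):
xb, yb = next(iter(train_loader))
print(xb.shape)                            # expected: torch.Size([64, 1, 28, 28])
print(xb.min().item(), xb.max().item())    # ToTensor keeps values in [0, 1]
print(yb[:8])                              # a few encoded class indices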
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=24):
        super(SimpleCNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),   # 28x28 -> 14x14
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),   # 14x14 -> 7x7
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),   # 7x7 -> 3x3 (floor division)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, 256), nn.ReLU(),   # 128 channels * 3 * 3 spatial = 1152 features
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        return x
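Before wiring up training, a dummy forward pass confirms that the 128 * 3 * 3 flatten size matches the actual feature-map dimensions; if the spatial math above were off, this would raise a shape error immediately:
dummy = torch.randn(2, 1, 28, 28)   # fake batch of two grayscale 28x28 images
print(SimpleCNN()(dummy).shape)     # expected: torch.Size([2, 24])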
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
Early stopping setup:
patience = 5                     # stop after 5 consecutive epochs without improvement
best_val_loss = float('inf')
counter = 0

num_epochs = 50
train_losses, val_losses = [], []
from tqdm import tqdm
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_train_loss = running_loss / len(train_loader)
train_losses.append(avg_train_loss)
    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_acc = correct / total
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # ---- Early stopping check ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        torch.save(model.state_dict(), "model.pth")   # keep the best checkpoint so far
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break
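Because early stopping lets training run a few epochs past the best point, the weights in memory after the loop are usually not the best ones. Reloading the checkpoint saved above restores them:
model.load_state_dict(torch.load("model.pth", map_location=device))
model.eval()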
import matplotlib.pyplot as plt
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()
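To close the loop, here is a minimal single-sample inference sketch on the validation set, using the `idx_to_letter` mapping built earlier (that mapping, and the hand-picked index 0, are illustrative choices):
sample, true_label = val_dataset[0]
with torch.no_grad():
    pred = model(sample.unsqueeze(0).to(device)).argmax(dim=1).item()   # add batch dim, take argmax
print("predicted:", idx_to_letter[pred], "| actual:", idx_to_letter[int(true_label)])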
In this post we completed the basic training pipeline for static sign language recognition: data preprocessing, augmentation, CNN design, and an early stopping strategy. The model already shows solid classification ability and is the first step toward real-time gesture recognition!
🔜 Next post preview:
We'll take the model from still images to a live camera feed: integrating OpenCV and MediaPipe to build real-time, interactive gesture recognition, so your model can truly see and understand you!