Model performance is three parts network, seven parts data. Even with ResNet or ViT, if the raw folders are littered with blank images, corrupt files, and duplicates, and the classes are imbalanced, training longer just means working harder for less. This article continues the hands-on work from the previous chapter: we clean and augment the train/ and test/ class folders directly, then do a stratified train/val split, ResNet18 fine-tuning, MixUp, progressive unfreezing, and Flip-TTA evaluation.
import os, random, hashlib
from glob import glob
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms as T, models
from torchvision.models import ResNet18_Weights
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
The cleaning step below walks FER2013's train/ and test/ class by class, converts images to grayscale, drops blank/corrupt files, and writes the output to FER2013_clean. If a split has already been cleaned, it is skipped.
def is_blank(img, low=2, high=253):
    """Roughly detect near-all-black/all-white images; img is a grayscale PIL image."""
    arr = np.array(img, dtype=np.uint8)
    m = float(arr.mean())
    return (m < low) or (m > high)
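# Sanity check (illustrative, not in the original script): an all-black canvas
# should be flagged as blank, a mid-gray one should not.
assert is_blank(Image.new("L", (48, 48), color=0))
assert not is_blank(Image.new("L", (48, 48), color=128))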
def clean_split_folder(raw_dir, clean_dir):
    os.makedirs(clean_dir, exist_ok=True)
    kept, skipped = 0, 0
    for cls in sorted(os.listdir(raw_dir)):
        src_cls = os.path.join(raw_dir, cls)
        dst_cls = os.path.join(clean_dir, cls)
        if not os.path.isdir(src_cls):
            continue
        os.makedirs(dst_cls, exist_ok=True)
        for fp in glob(os.path.join(src_cls, "*")):
            try:
                with Image.open(fp) as im:
                    im = im.convert("L")  # grayscale
                    if is_blank(im):      # overly black/white -> drop
                        skipped += 1
                        continue
                    base = os.path.splitext(os.path.basename(fp))[0]
                    im.save(os.path.join(dst_cls, f"{base}.png"))  # save uniformly as png
                    kept += 1
            except Exception:
                skipped += 1
    return kept, skipped
def ensure_cleaned(raw_root, clean_root):
    for split in ["train", "test"]:
        src = os.path.join(raw_root, split)
        dst = os.path.join(clean_root, split)
        if os.path.isdir(dst) and any(glob(os.path.join(dst, "*", "*"))):
            print(f"[CLEAN] {dst} already cleaned, skipping")
            continue
        kept, skipped = clean_split_folder(src, dst)
        print(f"[CLEAN] {src} → {dst}")
        print(f"[CLEAN] kept={kept}, skipped={skipped}")
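The functions above are defined but never invoked, and the intro also calls out duplicate files, which clean_split_folder does not handle. The call below uses an assumed raw-dataset root (the article only names the cleaned paths), and the dedup helper is a minimal optional sketch: it hashes file bytes per class folder and removes exact duplicates. Whether to also dedup test/ (or to check for train/test leakage) is a separate judgment call; here it runs on train/ only.

RAW_ROOT = "/Users/emily/FER2013"          # assumed location of the raw dataset
CLEAN_ROOT = "/Users/emily/FER2013_clean"
ensure_cleaned(RAW_ROOT, CLEAN_ROOT)

def dedup_folder(clean_dir):
    """Optional sketch: drop exact byte-level duplicates within each class folder."""
    removed = 0
    for cls in sorted(os.listdir(clean_dir)):
        cls_dir = os.path.join(clean_dir, cls)
        if not os.path.isdir(cls_dir):
            continue
        seen = set()
        for fp in sorted(glob(os.path.join(cls_dir, "*.png"))):
            with open(fp, "rb") as f:
                digest = hashlib.md5(f.read()).hexdigest()
            if digest in seen:
                os.remove(fp)      # keep the first copy, drop the rest
                removed += 1
            else:
                seen.add(digest)
    print(f"[DEDUP] {clean_dir}: removed={removed}")

dedup_folder(os.path.join(CLEAN_ROOT, "train"))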
Training uses conservative augmentation (to avoid distorting facial structure): slight crops, horizontal flips, small rotations, translation/scaling, brightness/contrast jitter, and RandomErasing. Validation/test use a fixed Resize + CenterCrop + Normalize.
IMNET_MEAN = (0.485, 0.456, 0.406)
IMNET_STD = (0.229, 0.224, 0.225)
train_tf = T.Compose([
    T.Grayscale(3),
    T.RandomResizedCrop(224, scale=(0.92, 1.0)),
    T.RandomHorizontalFlip(0.5),
    T.RandomRotation(7, fill=0),
    T.RandomAffine(degrees=0, translate=(0.02, 0.02), scale=(0.95, 1.05)),
    T.ColorJitter(brightness=0.12, contrast=0.12),
    T.ToTensor(),
    T.Normalize(IMNET_MEAN, IMNET_STD),
    T.RandomErasing(p=0.25, scale=(0.02, 0.08), ratio=(0.3, 3.3)),
])
eval_tf = T.Compose([
    T.Grayscale(3),
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(IMNET_MEAN, IMNET_STD),
])
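Before training it is worth eyeballing what train_tf actually produces, to confirm the conservative settings keep faces intact. Below is a minimal preview sketch, not part of the original pipeline; the sample path is a hypothetical placeholder, point it at any cleaned file. It applies the training transform several times to one image, undoes the normalization, and saves a small grid (RandomErasing patches will be visible, which is expected).

sample_fp = "/Users/emily/FER2013_clean/train/happy/0001.png"  # hypothetical example file
os.makedirs("outputs", exist_ok=True)
img = Image.open(sample_fp)
mean = torch.tensor(IMNET_MEAN).view(3, 1, 1)
std = torch.tensor(IMNET_STD).view(3, 1, 1)
fig, axes = plt.subplots(1, 4, figsize=(10, 3))
for ax in axes:
    x = train_tf(img) * std + mean               # undo Normalize for display
    ax.imshow(x.clamp(0, 1).permute(1, 2, 0))
    ax.axis("off")
plt.tight_layout(); plt.savefig("outputs/aug_preview.png", dpi=120); plt.close()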
Load the cleaned train/ with ImageFolder, then make a stratified split on the labels into train/val; the test set uses the cleaned test/ directly.
TRAIN_DIR = "/Users/emily/FER2013_clean/train"
TEST_DIR = "/Users/emily/FER2013_clean/test"
BATCH_TRAIN, BATCH_EVAL, SEED = 128, 256, 42
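# Added for reproducibility (the closing notes promise fixed seeds, but the
# original script only seeds the split): seed Python, NumPy, and torch alike.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)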
train_full = datasets.ImageFolder(TRAIN_DIR, transform=train_tf)
test_set = datasets.ImageFolder(TEST_DIR, transform=eval_tf)
class_names = train_full.classes
targets = [lbl for _, lbl in train_full.samples]
idx_all = np.arange(len(targets))
train_idx, val_idx = train_test_split(
    idx_all, test_size=0.2, stratify=targets, random_state=SEED
)
train_set = Subset(train_full, train_idx)  # training view keeps train_tf augmentation
val_set = Subset(datasets.ImageFolder(TRAIN_DIR, transform=eval_tf), val_idx)  # validation uses deterministic eval_tf
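# Sanity check (added sketch): confirm the stratified split preserved class
# proportions, and surface the class imbalance mentioned in the intro.
dist = pd.DataFrame({
    "train": pd.Series([targets[i] for i in train_idx]).value_counts().sort_index(),
    "val":   pd.Series([targets[i] for i in val_idx]).value_counts().sort_index(),
})
dist.index = class_names  # map label ids 0..N-1 to folder names
print(dist)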
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("mps") if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
          else torch.device("cpu"))
PIN = device.type == "cuda"
NUM_WORKERS = min(4, os.cpu_count() or 2)
common = dict(num_workers=NUM_WORKERS, pin_memory=PIN, persistent_workers=NUM_WORKERS > 0)
train_loader = DataLoader(train_set, batch_size=BATCH_TRAIN, shuffle=True, drop_last=True, prefetch_factor=2, **common)
val_loader = DataLoader(val_set, batch_size=BATCH_EVAL, shuffle=False, **common)
test_loader = DataLoader(test_set, batch_size=BATCH_EVAL, shuffle=False, **common)
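FER2013's classes are far from balanced (disgust is a small fraction of happy), and the loaders above do nothing about it. If you want to oversample rare classes, one common option is a WeightedRandomSampler; below is a minimal sketch, with the balanced loader left commented out because sampler and shuffle=True are mutually exclusive in DataLoader.

from torch.utils.data import WeightedRandomSampler

train_targets = np.array([targets[i] for i in train_idx])
class_counts = np.bincount(train_targets, minlength=len(class_names))
sample_weights = (1.0 / class_counts)[train_targets]   # rarer class -> higher draw probability
sampler = WeightedRandomSampler(torch.as_tensor(sample_weights, dtype=torch.double),
                                num_samples=len(train_targets), replacement=True)
# balanced_loader = DataLoader(train_set, batch_size=BATCH_TRAIN, sampler=sampler,
#                              drop_last=True, prefetch_factor=2, **common)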
# Model
weights = ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model = model.to(device)
# Freeze the backbone first
for name, p in model.named_parameters():
    if not name.startswith("fc."):
        p.requires_grad = False
backbone_params = [p for n, p in model.named_parameters() if not n.startswith("fc.")]
head_params = [p for n, p in model.named_parameters() if n.startswith("fc.")]
optimizer = optim.AdamW([{"params": head_params, "lr": 3e-4}], weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                 patience=2, threshold=1e-4, min_lr=1e-6)
# MixUp
def mixup_batch(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1.0 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return x_mix, y_a, y_b, lam
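# Quick semantics check for mixup_batch (added sketch): with alpha=0.2 the
# Beta(0.2, 0.2) draw tends toward 0 or 1, so most mixed images stay close to
# one of the two originals; the training loop weighs both labels by lam.
xb_demo, yb_demo = torch.randn(4, 3, 224, 224), torch.tensor([0, 1, 2, 3])
xm, ya_demo, yb2_demo, lam_demo = mixup_batch(xb_demo, yb_demo)
print(xm.shape, ya_demo.tolist(), yb2_demo.tolist(), round(float(lam_demo), 3))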
EPOCHS, FREEZE_EPOCHS = 30, 3
best_val_acc, patience_es, no_improve = 0.0, 5, 0
os.makedirs("outputs", exist_ok=True)  # checkpoint directory must exist before torch.save
for epoch in range(1, EPOCHS + 1):
    if epoch == FREEZE_EPOCHS + 1:  # backbone stays frozen for the first FREEZE_EPOCHS epochs
        for p in backbone_params:
            p.requires_grad = True
        optimizer = optim.AdamW([
            {"params": backbone_params, "lr": 1e-4},
            {"params": head_params, "lr": 3e-4},
        ], weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                         patience=2, threshold=1e-4, min_lr=1e-6)
        print(">>> Backbone unfrozen: full fine-tuning")
    # --- Train ---
    model.train()
    run = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        xb_mix, ya, yb2, lam = mixup_batch(xb, yb, alpha=0.2)
        out = model(xb_mix)
        loss = lam * criterion(out, ya) + (1 - lam) * criterion(out, yb2)
        loss.backward(); optimizer.step()
        run += loss.item()
    train_loss = run / max(1, len(train_loader))
    # --- Val ---
    model.eval()
    vloss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            vloss += criterion(out, yb).item()
            pred = out.argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    val_loss = vloss / max(1, len(val_loader))
    val_acc = correct / max(1, total)
    scheduler.step(val_loss)
    print(f"Epoch {epoch:02d}/{EPOCHS} TrainLoss:{train_loss:.4f} ValLoss:{val_loss:.4f} ValAcc:{val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc, no_improve = val_acc, 0
        torch.save(model.state_dict(), "outputs/fer_resnet18_best.pth")
    else:
        no_improve += 1
        if no_improve >= patience_es:
            print("Early stopping triggered."); break
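The loop prints per-epoch metrics but keeps no history. If you also want learning curves, declare history = {"train_loss": [], "val_loss": [], "val_acc": []} before the epoch loop, append the three metrics at the end of each epoch, and then plot; below is a minimal sketch of the plotting side under that assumption.

def plot_history(history, out_path="outputs/curves.png"):
    # history: the per-epoch metric lists collected as described above
    plt.figure(figsize=(6, 4))
    plt.plot(history["train_loss"], label="train loss")
    plt.plot(history["val_loss"], label="val loss")
    plt.plot(history["val_acc"], label="val acc")
    plt.xlabel("epoch"); plt.legend()
    plt.tight_layout(); plt.savefig(out_path, dpi=120); plt.close()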
# Load the best weights
model.load_state_dict(torch.load("outputs/fer_resnet18_best.pth", map_location=device))
model.eval()
def predict_tta(xb):
    # Flip-TTA: average the logits of the original and horizontally flipped batch
    with torch.no_grad():
        l1 = model(xb)
        l2 = model(torch.flip(xb, dims=[3]))  # dim 3 = width in NCHW
        return (l1 + l2) / 2
all_y, all_p = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        preds = predict_tta(xb).argmax(1).cpu().numpy().tolist()
        all_p += preds
        all_y += yb.numpy().tolist()
print("\n=== Classification Report (Test, Flip-TTA) ===")
print(classification_report(all_y, all_p, target_names=class_names, digits=4))
cm = confusion_matrix(all_y, all_p)
plt.figure(figsize=(7, 6))
try:
    import seaborn as sns
    sns.heatmap(cm, annot=False, cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
except Exception:
    plt.imshow(cm, cmap="Blues")  # fallback if seaborn is unavailable
    plt.xticks(range(len(class_names)), class_names, rotation=45, ha="right")
    plt.yticks(range(len(class_names)), class_names)
plt.title("FER2013 Confusion Matrix (ResNet18 + MixUp + Flip-TTA)")
plt.xlabel("Predicted"); plt.ylabel("True")
plt.tight_layout(); plt.savefig("outputs/confusion_matrix.png", dpi=150); plt.close()
When the folders are clean (corrupt/blank images removed), consistent (channel count and normalization aligned with the pretrained weights), sensibly augmented (training set only), and the split is stratified and reproducible (fixed seeds, best weights saved, evaluation exported), a pretrained model like ResNet18 can perform the way it should. The pipeline above runs end to end on the image version of FER2013 and gives more stable results; to push higher, improve data quality further or layer on advanced techniques (e.g. TenCrop-TTA, EMA, SWA, CutMix). If you like, I can also fold these "boosters" into the same script.