如果我們手上只有數百張貓狗的照片,是沒辦法像前幾天一樣訓練出像 ResNet 這種好的分類器,反而會遇到過擬合的問題。在這種缺乏數據的情況下,我們能用遷移學習 (transfer learning) 跟資料增強 (data augmentation) 兩大方法來解決這個問題。
遷移學習的核心思想為:將從一個任務中學到的知識,應用到另一個相關的任務上。例如,我們從 ImageNet 上訓練好的大型 CNN 模型,他的特徵提取能力是能挪用到我們前面的貓狗分類任務上。我們沒必要從 0 開始再去訓練一個模型。
實務上,遷移學習有兩種策略
作為特徵提取器 (feature extractor):當我們的目標數據集非常小,且與源數據集(如 ImageNet)比較相似時,我們可以載入一個預訓練好的 CNN 模型(例如 MobileNetV2),去掉它最頂部的全連接分類層(因為它原本是用來分 ImageNet 的1000類的)。然後,我們「凍結 (freeze)」前面所有卷積層的權重,不讓它們在訓練中被更新。我們只在這些被凍結的卷積層之上,添加一個我們自己的、新的、小型的分類器(例如一個新的全連接層),並只訓練這個新的分類器。
微調 (fine-tuning):當我們的目標數據集較大,或者與源數據集差異較大時,我們同樣載入預訓練模型並替換掉頂部分類層。但這次,我們不再完全凍結前面的卷積層。我們會讓整個網路都參與訓練,但通常會為淺層的卷積層設置一個非常小的學習率 (learning rate),而為我們自己新增的分類層設置一個較大的學習率。
即使有了遷移學習,如果我們的訓練數據只有幾百張,模型還是很容易發生過擬合。為了解決這個問題,我們需要資料增強,對現有的訓練圖片進行一系列隨機的、輕微的變換,來「無中生有」訓練樣本。
常見的技巧有
隨機水平翻轉:一張貓的圖片,水平翻轉後,它依然是一張貓。
隨機旋轉:在一個小角度範圍內(如 -15 到 +15 度)隨機旋轉。
隨機裁切與縮放:隨機地從原圖中裁切出一塊區域,並將其縮放到目標尺寸。這模擬了物體在圖片中位置和大小的變化。
顏色抖動:隨機地改變圖片的亮度、對比度、飽和度和色調。
資料增強豐富了訓練數據的多樣性,強迫模型去學習物體更本質的特徵,而不是記住一些偶然的細節,顯著提升了模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import os
import urllib.request
import zipfile
# --- 0. 下載並解壓資料集 ---
def download_hymenoptera_data():
data_dir = 'hymenoptera_data'
if not os.path.exists(data_dir):
print("正在下載 Hymenoptera 資料集...")
url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
urllib.request.urlretrieve(url, 'hymenoptera_data.zip')
print("正在解壓縮...")
with zipfile.ZipFile('hymenoptera_data.zip', 'r') as zip_ref:
zip_ref.extractall('.')
os.remove('hymenoptera_data.zip')
print("資料集準備完成!")
else:
print("資料集已存在")
if __name__ == '__main__':
# 下載資料集
download_hymenoptera_data()
# --- 1. 設定超參數與設備 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.001
batch_size = 16
num_epochs = 15
# --- 2. 資料增強與數據載入 ---
# 為訓練集和驗證集定義不同的 transform
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224), # 隨機裁切並縮放
transforms.RandomHorizontalFlip(), # 隨機水平翻轉
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
# 下載並準備數據集 (PyTorch 會自動下載)
data_dir = 'hymenoptera_data'
# 使用 ImageFolder,它會自動從資料夾名稱 ('ants', 'bees') 推斷標籤
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=0)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(f"資料集類別: {class_names}")
print(f"訓練集大小: {dataset_sizes['train']}, 驗證集大小: {dataset_sizes['val']}")
# --- 3. 載入預訓練模型並修改分類層 ---
# 使用 ResNet-18
model = models.resnet18(pretrained=True)
# 凍結所有預訓練層的權重
for param in model.parameters():
param.requires_grad = False
# 獲取最後一個全連接層的輸入特徵數
num_ftrs = model.fc.in_features
# 替換掉原本的分類層,換成我們自己的 (輸出為2,因為只有螞蟻和蜜蜂)
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)
print("\n模型結構已修改完成!")
# --- 4. 定義損失函數與優化器 ---
criterion = nn.CrossEntropyLoss()
# 注意:我們只將需要更新的參數 (新的 fc 層) 傳給優化器
optimizer = optim.SGD(model.fc.parameters(), lr=learning_rate, momentum=0.9)
# --- 5. 訓練與評估模型 ---
print("開始訓練...")
for epoch in range(num_epochs):
print(f'\nEpoch {epoch+1}/{num_epochs}')
print('-' * 10)
# 每個 epoch 都有一個訓練和驗證階段
for phase in ['train', 'val']:
if phase == 'train':
model.train() # 設為訓練模式
else:
model.eval() # 設為評估模式
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
# 只在訓練階段計算梯度
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
# 只在訓練階段進行反向傳播和最佳化
if phase == 'train':
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
print("\n訓練完成!")
...
Epoch 15/15
----------
train Loss: 0.2320 Acc: 0.8934
val Loss: 0.1735 Acc: 0.9608