到目前為止,我們所學的所有監督式模型,都是建立在大量高品質的人工標註數據上。這個過程耗時耗力,也限制了模型的應用規模。
我們是否也能讓模型像人類一樣,單純透過觀察、互動與推理,來自己學習到有意義的視覺特徵呢?
自監督學習 (self-supervised learning) 的核心思想是:從數據本身,自動地創造出偽標籤,然後像監督式學習一樣進行訓練。模型需要完成一個由我們設計的「代理任務」(pretext task),在完成任務的過程中,它會學習到高品質、關於數據的特徵表徵 (feature reprentation)。
早期常見的代理任務有
圖片上色
圖片修復
拼圖
預測旋轉
雖然這些代理任務有效,但容易造成學到的特徵與任務本身關聯太大,對比學習 (contrastive learning) 也在這個情況下誕生。
對比學習的核心思想為:將相似的樣本,在特徵空間中拉近;將不相似的樣本,在特徵空間中推開。它不再要求模型去重建或預測像素,而是直接在特徵層面進行學習。
例如常見的 SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) 的流程如下
數據增強:從一個 mini-batch 的未標註圖片中,隨機抽取一張圖片 x。對這張圖片進行兩次不同的隨機數據增強,得到兩個互為「正樣本對 (positive pair)」的視圖 x_i 和 x_j。我們知道,儘管它們在像素層面不同,但它們的語意是完全相同的。
編碼器:使用一個 CNN 編碼器網路 f(),分別提取 x_i 和 x_j 的特徵表示 h_i 和 h_j。這個編碼器是我們要訓練的主體。
投影頭 (projection head):將特徵 h_i 和 h_j,再通過一個小型的 ANN 網路 g()(投影頭),將其映射到一個新的特徵空間中,得到 z_i 和 z_j。在這個空間中,我們將計算對比損失。
計算損失 (contrastive loss): 對於 z_i 來說:
它的「正樣本」是 z_j。
這個 mini-batch 中所有其他圖片生成的圖,都是它的「負樣本 (negative samples)」。
損失函數的目標是,最大化 z_i 和 z_j 之間的相似度,同時,最小化 z_i 與所有負樣本之間的相似度。
這個過程中,編碼器 被迫去學習一種對數據增強不變的特徵表示,它必須忽略掉顏色、旋轉、裁切等表面差異,而去捕捉物體更本質、更抽象的語意資訊。實驗也證明,這個方法預訓練好的編碼器表現非常好。
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
def calculate_similarity(z1, z2):
"""計算兩個張量之間的餘弦相似度"""
z1 = F.normalize(z1, dim=-1)
z2 = F.normalize(z2, dim=-1)
return torch.mm(z1, z2.t())
def nt_xent_loss(z_i, z_j, temperature=0.5):
"""
計算 NT-Xent Loss (SimCLR 中使用的對比損失)。
z_i, z_j: 兩個視圖的投影特徵,尺寸為 [batch_size, feature_dim]
"""
batch_size = z_i.shape[0]
# 1. 將兩個視圖的特徵拼接在一起
z = torch.cat([z_i, z_j], dim=0) # 尺寸變為 [2*batch_size, feature_dim]
# 2. 計算兩兩之間的相似度矩陣
sim_matrix = calculate_similarity(z, z) # 尺寸為 [2*batch_size, 2*batch_size]
# 3. 創建正樣本對的標籤 (mask)
# 正樣本對是 (i, i+N) 和 (i+N, i)
labels = torch.arange(batch_size).to(z.device)
labels = torch.cat([labels + batch_size, labels], dim=0)
# 4. 忽略掉自己和自己的相似度 (對角線元素)
# 建立一個對角線為0,其餘為1的 mask
mask = torch.eye(2 * batch_size, dtype=torch.bool).to(z.device)
# ~mask 是布林反轉
sim_matrix = sim_matrix[~mask].view(2 * batch_size, -1)
# 5. 選出正樣本對的相似度
# 從每一行中,根據 labels 索引,選出正樣本對應的相似度分數
adjusted_labels = labels.clone()
for i in range(2 * batch_size):
if adjusted_labels[i] >= i:
adjusted_labels[i] -= 1
positive_pairs = sim_matrix[torch.arange(2*batch_size), adjusted_labels]
# 6. 計算損失
# logits = [positive_pair_sim, negative_pair_sim_1, negative_pair_sim_2, ...]
logits = torch.cat([positive_pairs.unsqueeze(1), sim_matrix], dim=1)
logits /= temperature
# 正樣本對的標籤永遠是第一個 (索引為0)
loss_labels = torch.zeros(2 * batch_size).long().to(z.device)
loss = F.cross_entropy(logits, loss_labels)
return loss
# --- 演示 ---
if __name__ == '__main__':
# 假設我們有一個 batch_size=2 的 mini-batch
# 圖片 1 (貓), 圖片 2 (狗)
# 假設經過數據增強和編碼器+投影頭後,得到以下特徵
# z1_v1, z1_v2 是貓的兩個視圖的特徵
# z2_v1, z2_v2 是狗的兩個視圖的特徵
z1_v1 = torch.randn(1, 128)
z1_v2 = z1_v1 + 0.1 * torch.randn(1, 128) # 貓的第二個視圖,與第一個相似
z2_v1 = torch.randn(1, 128)
z2_v2 = z2_v1 + 0.1 * torch.randn(1, 128) # 狗的第二個視圖,與第一個相似
# 組成一個 batch
batch_view_1 = torch.cat([z1_v1, z2_v1], dim=0) # [貓v1, 狗v1]
batch_view_2 = torch.cat([z1_v2, z2_v2], dim=0) # [貓v2, 狗v2]
# 計算損失
loss = nt_xent_loss(batch_view_1, batch_view_2, temperature=0.5)
print(f"計算出的對比損失為: {loss.item():.4f}")
# 梯度會驅使 z1_v1 和 z1_v2 的相似度變高,
# 同時驅使 z1_v1 和 z2_v1, z2_v2 的相似度變低。
結果
計算出的對比損失為: 0.7955