iT邦幫忙

2025 iThome 鐵人賽

DAY 16
0
AI & Data

從0開始:傳統圖像處理到深度學習模型系列 第 16

Day 16 - 卷積神經網路(四)Grad-CAM

  • 分享至 

  • xImage
  •  

雖然我們學會了怎麼用各種不同 CNN 模型來進行圖像辨認,但對他的認識仍停留在黑盒子的階段。模型在做出「這是一隻貓」的決斷時,還是不知道它是依據什麼做判斷的,而這點關乎到對這模型的信任程度。因此今天我們要對他進行視覺化,來知道模型究竟是注意到圖像哪個區域,來做出最後的決策。

Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) 是一種無需修改模型結構、可以應用於幾乎所有 CNN 架構的視覺化技術。它使用梯度來判斷重要性,具體流程如下

  1. 前向傳播:將一張輸入圖片餵給訓練好的 CNN 模型,進行一次完整的前向傳播,得到最終的分類分數(在 softmax 之前)。

  2. 選定目標:我們選定一個我們感興趣的類別,例如,我們想知道模型為何將圖片判斷為「黃金獵犬」。

  3. 計算梯度:我們計算出這個「黃金獵犬」類別的分數,對於模型最後一個卷積層的每一張特徵圖的梯度。這個梯度,直觀地反映了「如果我稍微改變這張特徵圖上某個像素的值,會對最終的『黃金獵犬』分數產生多大的影響」。

  4. 計算權重:對每一張特徵圖的梯度圖,進行全域平均池化 (global average pooling),得到一個純量值。這個值,就被當作是這張特徵圖對於「黃金獵犬」這個類別的重要性權重。梯度越大,權重越高。

  5. 加權求和:將最後一個卷積層的所有特徵圖,根據它們各自計算出的重要性權重進行加權求和。這就得到了一張粗糙的、融合了所有重要特徵的熱圖 (heatmap)。

  6. 視覺化:將這張熱圖放大到與原始輸入圖片相同的大小,並將其用偽彩色疊加在原始圖片上。熱圖中顏色越熱的區域,就是模型在做出該類別判斷時,最關注的部分。

應用 Grad-CAM 於 ResNet50

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2
import matplotlib.pyplot as plt

# --- 1. 載入模型和準備圖片 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=True).to(device)
model.eval()

# 找到 ResNet-50 的最後一個卷積層 (通常是 layer4 的最後一個 bottleneck)
target_layer = model.layer4[-1]

# 圖片 URL
image_url = "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg" 

response = requests.get(image_url)
img_pil = Image.open(BytesIO(response.content)).convert('RGB')

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(img_pil).unsqueeze(0).to(device)

# --- 2. 實現 Grad-CAM ---
# 我們需要儲存目標層的特徵圖和梯度
feature_maps = []
gradients = []

def save_feature_map(module, input, output):
    feature_maps.append(output)

def save_gradient(module, grad_in, grad_out):
    gradients.append(grad_out[0])

# 使用 "hooks" 來掛載我們的儲存函式
# handle_forward 會在 target_layer 執行完 forward 後被呼叫
# handle_backward 會在 target_layer 計算完 backward 梯度後被呼叫
handle_forward = target_layer.register_forward_hook(save_feature_map)
handle_backward = target_layer.register_backward_hook(save_gradient)

# 前向傳播
output = model(input_tensor)
# 找到預測分數最高的類別作為我們的目標
target_category = torch.argmax(output, dim=1).item()
print(f"預測的類別索引: {target_category}")

# 反向傳播
# 我們只關心目標類別的分數,所以創建一個 one-hot 張量
one_hot = torch.zeros_like(output)
one_hot[0][target_category] = 1
model.zero_grad()
output.backward(gradient=one_hot, retain_graph=True)

# 移除 hooks
handle_forward.remove()
handle_backward.remove()

# --- 3. 計算並生成熱圖 ---
# gradients[0] 的尺寸: [batch, channels, height, width]
# feature_maps[0] 的尺寸: [batch, channels, height, width]

# 4. 計算權重 (全域平均池化)
pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3]) # 對 H, W 維度求平均

# 5. 加權求和
activations = feature_maps[0].squeeze(0)
for i in range(activations.shape[0]):
    activations[i, :, :] *= pooled_gradients[i]

# 計算熱圖 (在通道維度上求和)
heatmap = torch.mean(activations, dim=0).cpu().detach().numpy()

# ReLU 操作,只保留正值
heatmap = np.maximum(heatmap, 0)
# 正規化
heatmap /= np.max(heatmap)

# --- 4. 視覺化 ---
img_cv = cv2.cvtColor(np.array(img_pil.resize((224, 224))), cv2.COLOR_RGB2BGR)
heatmap_cv = cv2.resize(heatmap, (img_cv.shape[1], img_cv.shape[0]))
heatmap_cv = np.uint8(255 * heatmap_cv)
heatmap_color = cv2.applyColorMap(heatmap_cv, cv2.COLORMAP_JET)

superimposed_img = heatmap_color * 0.4 + img_cv
superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

# 顯示結果
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(img_pil.resize((224,224)))
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
# Matplotlib 顯示 BGR 圖片需要轉回 RGB
plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
plt.title("Grad-CAM")
plt.axis('off')

plt.show()

結果
https://ithelp.ithome.com.tw/upload/images/20250826/20178100oTl7FtAYlh.png


上一篇
Day 15 - 卷積神經網路(三) GoogLeNet 與 ResNet
下一篇
Day 17 - 卷積神經網路(五)MobileNet
系列文
從0開始:傳統圖像處理到深度學習模型23
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言