iT邦幫忙

2025 iThome 鐵人賽

DAY 10
0
AI & Data

從0開始:傳統圖像處理到深度學習模型系列 第 10

Day 10 - 機器學習初探(一) KNN

  • 分享至 

  • xImage
  •  

機器學習簡介

我們對傳統電腦視覺領域中,基於幾何和梯度的方法論,已經有了非常深入的理解和實踐。這些方法在處理具有明確規則和幾何結構的任務時,表現得非常出色,但如果問題變成「這張圖片裡的動物是貓還是狗」呢?此時我們就需要用到機器學習的概念,讓電腦學習規則。

機器學習與傳統程式設計的差別如下

  • 傳統程式設計: 資料 + 程式(規則)→ 輸出

  • 機器學習: 資料 + 輸出(標籤)→ 程式(模型/規則)

機器學習之下大致上又可以分為四大類

  1. 監督式學習:最常見的一種,我們提供給機器的訓練資料,不僅包含「輸入」(例如貓狗圖片),還包含我們人工標註好的「正確答案」(例如標籤貓或狗)。

  2. 無監督式學習:我們只提供「輸入」資料,沒有任何「正確答案」。機器的目標是自己去探索資料中潛在的結構和模式,例如分群 (clustering)

  3. 半監督式學習:介於前兩者之間,訓練資料中只有一小部分有標籤,大部分沒有。

  4. 強化學習:機器在一個環境中,透過不斷地「嘗試」和「從錯誤中學習」來做出決策,以最大化它能獲得的獎勵

KNN

K-近鄰演算法 (K-Nearest Neighbors, KNN) 是最簡單、最直觀的機器學習演算法。

當我們有一個新的、未知的數據點需要分類時,KNN 的流程是

  • 計算這個新數據點與所有已知訓練數據點之間的「距離」。

  • 找出距離最近的 K 個鄰居。

  • 在這 K 個鄰居中,進行投票。哪個類別的鄰居最多,我們就預測這個新數據點屬於哪個類別。

KNN 分別有以下優缺點

  • 優點:演算法非常簡單,無需訓練階段。

  • 缺點:預測階段的計算成本非常高,因為需要計算與所有訓練點的距離。K 值亦不宜過大或過小。

辨識手寫數字

首先先安裝 scikit-learn。

pip install scikit-learn

我們可以使用 MNIST 手寫數字資料集來進行我們的辨識任務

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler

# --- 1. 載入並準備數據 ---
print("正在載入 MNIST 數據集,請稍候...")
# fetch_openml 會從網路上抓取數據集
# 每一張圖片是 28x28 = 784 個像素點,所以 X 的維度是 (樣本數, 784)
# y 是每個樣本對應的標籤 ('0', '1', ..., '9')
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False, parser='liac-arff')
print("數據載入完成!")

# 為了加快執行速度,我們只使用數據集的一部分
X_sample = X[:10000]
y_sample = y[:10000]

# 將數據集分為「訓練集」和「測試集」
# 訓練集用來「學習」,測試集用來評估模型的表現
# test_size=0.2 表示 20% 的數據用作測試
X_train, X_test, y_train, y_test = train_test_split(X_sample, y_sample, test_size=0.2, random_state=42)

# --- 2. 特徵縮放 ---
# 像素值的範圍是 0-255,尺度差異大,需要進行標準化
# 標準化:(x - mean) / std_dev
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test) # 注意:測試集使用訓練集的 scaler 進行 transform

# --- 3. 建立並訓練 KNN 模型 ---
# n_neighbors 就是超參數 K
# n_jobs=-1 表示使用所有可用的 CPU 核心來加速計算距離
print("正在建立 KNN 模型 (K=5)...")
knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)
knn.fit(X_train_scaled, y_train) # KNN 的 "fit" 只是把訓練數據儲存起來
print("模型建立完畢!")

# --- 4. 進行預測 ---
print("正在對測試集進行預測...")
y_pred = knn.predict(X_test_scaled)
print("預測完成!")

# --- 5. 評估模型表現 ---
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型在測試集上的準確率: {accuracy * 100:.2f}%")

# --- 6. 視覺化預測結果 ---
# 隨機選擇幾個測試樣本來看看預測得對不對
plt.figure(figsize=(12, 5))
for i in range(10):
    # 從測試集中隨機選一個樣本
    idx = np.random.randint(0, len(X_test))
    image = X_test[idx].reshape(28, 28)
    true_label = y_test[idx]
    
    # 獲取模型對這個單一樣本的預測
    predicted_label = knn.predict(X_test_scaled[idx].reshape(1, -1))[0]
    
    plt.subplot(2, 5, i + 1)
    plt.imshow(image, cmap='gray')
    plt.title(f"True: {true_label}\nPred: {predicted_label}",
              color=("green" if true_label == predicted_label else "red"))
    plt.axis('off')

plt.tight_layout()
plt.show()

結果

模型在測試集上的準確率: 91.20%

https://ithelp.ithome.com.tw/upload/images/20250820/20178100mmz0yKj7k7.png


上一篇
Day 9 – RANSAC 與全景照片
下一篇
Day 11 - 機器學習初探(二) HOG 與 SVM
系列文
從0開始:傳統圖像處理到深度學習模型23
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言