iT邦幫忙

2024 iThome 鐵人賽

DAY 6
0

今天是第六天我想寫一個yolo 模型的訓練來標記我的斑馬魚

import cv2
import numpy as np
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk

def load_yolo_model(weight_file, cfg_file, names_file):
    net = cv2.dnn.readNet(weight_file, cfg_file)
    with open(names_file, 'r') as f:
        classes = f.read().strip().split("\n")
    return net, classes

def yolo_calculate(image_path, weight_file_path, cfg_file_path, names_file_path):
    # 加載 YOLO 模型
    net, classes = load_yolo_model(weight_file_path, cfg_file_path, names_file_path)
    
    # 讀取影像
    img = cv2.imread(image_path)
    height, width = img.shape[:2]
    
    # 進行 YOLO 目標偵測
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (416, 416), swapRB=True, crop=False)
    net.setInput(blob)
    layer_names = net.getLayerNames()
    output_layers = [layer_names[i - 1] for i in net.getUnconnectedOutLayers()]
    detections = net.forward(output_layers)
    
    # 處理 YOLO 輸出
    boxes = []
    confidences = []
    class_ids = []
    
    for output in detections:
        for detection in output:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > 0.5 and classes[class_id] == 'zebrafish':  # 偵測到斑馬魚
                box = detection[0:4] * np.array([width, height, width, height])
                (centerX, centerY, w, h) = box.astype("int")
                x = int(centerX - (w / 2))
                y = int(centerY - (h / 2))
                
                boxes.append([x, y, int(w), int(h)])
                confidences.append(float(confidence))
                class_ids.append(class_id)
    
    # 非極大值抑制以消除多餘的重疊框
    indices = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold=0.5, nms_threshold=0.4)
    
    # 繪製標記框
    if len(indices) > 0:
        for i in indices.flatten():
            x, y, w, h = boxes[i]
            label = f"{classes[class_ids[i]]}: {confidences[i]:.2f}"
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.putText(img, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    
    # 儲存並返回處理過的影像
    result_image_path = "yolo_output.png"
    cv2.imwrite(result_image_path, img)
    return result_image_path

# 將此函數整合到原本的 perform_yolo 函數中:
def perform_yolo():
    file_path = file_path_var.get()
    weight_file_path = weight_value_var.get()
    cfg_file_path = "yolov4.cfg"  # 替換為你的 cfg 檔案路徑
    names_file_path = "coco.names"  # 替換為你的名稱檔案路徑

    result_image_path = yolo_calculate(file_path, weight_file_path, cfg_file_path, names_file_path)
    load_image(result_image_path)
    show_zebrafish_window()
    ```
    

### 1. **匯入需要的模組**
```python
import cv2
import numpy as np
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk

這裡匯入了所需的模組:

  • cv2 是 OpenCV,用來進行影像處理和目標偵測。
  • numpy (np) 是用於數值計算的工具,這裡用來處理 YOLO 的輸出。
  • tkinter 是 Python 的標準 GUI 庫,用來建立圖形界面應用程式。
  • filedialogmessageboxtkinter 模組中的子模組,分別用於文件選擇對話框和訊息框。
  • ttktkinter 的子模組,提供了更多樣式的元件(如按鈕、標籤等)。

2. YOLO 模型加載函數

def load_yolo_model(weight_file, cfg_file, names_file):
    net = cv2.dnn.readNet(weight_file, cfg_file)
    with open(names_file, 'r') as f:
        classes = f.read().strip().split("\n")
    return net, classes

這個函數用於加載 YOLO 模型和分類名稱:

  • weight_file:YOLO 的權重文件。
  • cfg_file:YOLO 的配置文件,定義了模型的架構。
  • names_file:存儲物件名稱的文件,每行一個名稱。

cv2.dnn.readNet 函數會根據權重文件和配置文件建立一個深度學習網路模型 (net)。接著,打開並讀取名稱文件 (names_file),將每個名稱存儲到一個列表中。

3. YOLO 計算和標記函數

def yolo_calculate(image_path, weight_file_path, cfg_file_path, names_file_path):
    # 加載 YOLO 模型
    net, classes = load_yolo_model(weight_file_path, cfg_file_path, names_file_path)
    
    # 讀取影像
    img = cv2.imread(image_path)
    height, width = img.shape[:2]

這裡是主函數 yolo_calculate,用來處理影像並標記斑馬魚:

  • image_path 是要處理的影像路徑。
  • 使用 cv2.imread 讀取影像並獲取其寬度和高度。
    # 進行 YOLO 目標偵測
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (416, 416), swapRB=True, crop=False)
    net.setInput(blob)
    layer_names = net.getLayerNames()
    output_layers = [layer_names[i - 1] for i in net.getUnconnectedOutLayers()]
    detections = net.forward(output_layers)

這部分是進行 YOLO 偵測:

  • 將影像轉換成 YOLO 可以處理的輸入格式 (blob)。
  • net.setInput 設定輸入。
  • net.getLayerNames 獲取網路中所有的層名稱,net.getUnconnectedOutLayers 獲取需要的輸出層。
  • net.forward 執行前向傳播,獲取偵測結果。
    # 處理 YOLO 輸出
    boxes = []
    confidences = []
    class_ids = []
    
    for output in detections:
        for detection in output:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > 0.5 and classes[class_id] == 'zebrafish':  # 偵測到斑馬魚
                box = detection[0:4] * np.array([width, height, width, height])
                (centerX, centerY, w, h) = box.astype("int")
                x = int(centerX - (w / 2))
                y = int(centerY - (h / 2))
                
                boxes.append([x, y, int(w), int(h)])
                confidences.append(float(confidence))
                class_ids.append(class_id)

這部分處理 YOLO 的輸出:

  • 每個偵測結果包含物件的位置信息(前四個值)和各類別的信心分數。
  • 取分數最高的類別 (class_id),並確定信心值 (confidence) 大於 0.5 且該類別為斑馬魚 (zebrafish) 才進行後續處理。
  • 計算邊界框的位置與大小,並將其存入 boxes 列表中。信心值和類別 ID 分別存入 confidencesclass_ids
    # 非極大值抑制以消除多餘的重疊框
    indices = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold=0.5, nms_threshold=0.4)
    
    # 繪製標記框
    if len(indices) > 0:
        for i in indices.flatten():
            x, y, w, h = boxes[i]
            label = f"{classes[class_ids[i]]}: {confidences[i]:.2f}"
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.putText(img, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    
    # 儲存並返回處理過的影像
    result_image_path = "yolo_output.png"
    cv2.imwrite(result_image_path, img)
    return result_image_path

這段程式碼進行了兩個步驟:

  1. 非極大值抑制 (NMS):過濾掉多餘的重疊框,確保每個物件只保留一個最好的框。
  2. 繪製標記框:在圖像上繪製邊界框和標籤,並將處理後的圖像保存為 yolo_output.png

最後,返回處理過的圖像路徑。

4. 整合到 tkinter 的按鈕函數

def perform_yolo():
    file_path = file_path_var.get()
    weight_file_path = weight_value_var.get()
    cfg_file_path = "yolov4.cfg"  # 替換為你的 cfg 檔案路徑
    names_file_path = "coco.names"  # 替換為你的名稱檔案路徑

    result_image_path = yolo_calculate(file_path, weight_file_path, cfg_file_path, names_file_path)
    load_image(result_image_path)
    show_zebrafish_window()

這個函數是當你在 GUI 中按下按鈕時觸發的事件:

  • 獲取使用者選擇的圖像路徑和權重文件路徑。
  • 調用 yolo_calculate 函數來進行斑馬魚標記。
  • 最後顯示處理後的圖像並打開新的斑馬魚分析視窗。

總結

這段程式碼展示了如何使用 YOLO 模型自動標記斑馬魚,並將其整合到 tkinter 的 GUI 應用程式中。主要流程是:

  1. 加載 YOLO 模型。
  2. 將圖像進行目標偵測。
  3. 繪製標記框並顯示結果。

上一篇
day 5 yolo多隻斑馬魚的行為分析
下一篇
day 7 yolo v8 模擬特斯拉辨識系統
系列文
基於人工智慧與深度學習對斑馬魚做行為分析30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言