iT邦幫忙

2024 iThome 鐵人賽

DAY 28
0
AI/ ML & Data

基於人工智慧與深度學習對斑馬魚做行為分析系列 第 28

day 28 基於人工智慧與深度學習斑馬魚行為分析

  • 分享至 

  • xImage
  •  

今天是第二十八天我們可以寫一個人工智慧與深度學習斑馬魚行為分析的高階版程式,以下是程式碼

import cv2
import numpy as np
import torch
from ultralytics import YOLO
from keras.models import Sequential, load_model
from keras.layers import LSTM, Dense, Dropout
from scipy.optimize import linear_sum_assignment
from collections import deque

# Kalman Filter class for object tracking
class KalmanTracker:
    def __init__(self):
        self.kalman = cv2.KalmanFilter(4, 2)
        self.kalman.measurementMatrix = np.array([[1, 0, 0, 0],
                                                  [0, 1, 0, 0]], np.float32)
        self.kalman.transitionMatrix = np.array([[1, 0, 1, 0],
                                                 [0, 1, 0, 1],
                                                 [0, 0, 1, 0],
                                                 [0, 0, 0, 1]], np.float32)
        self.kalman.processNoiseCov = np.array([[1, 0, 0, 0],
                                                [0, 1, 0, 0],
                                                [0, 0, 1, 0],
                                                [0, 0, 0, 1]], np.float32) * 0.03
        self.prediction = np.zeros((2, 1), np.float32)

    def update(self, coord):
        measurement = np.array([[np.float32(coord[0])],
                                [np.float32(coord[1])]])
        self.kalman.correct(measurement)
        self.prediction = self.kalman.predict()
        return self.prediction

# Load YOLOv8 model
yolo_model = YOLO("yolov8n.pt")  # 替換為你的YOLOv8模型權重

# Load pre-trained LSTM model
lstm_model = load_model('/path/to/your/lstm_model.h5')

# Multi-layer LSTM model definition (if you want to build from scratch)
def build_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(64, return_sequences=True, input_shape=input_shape))
    model.add(Dropout(0.2))
    model.add(LSTM(64, return_sequences=False))
    model.add(Dropout(0.2))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(3, activation='softmax'))  # 假設你有三種類別的行為
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Set YOLOv8 input size
input_size = 640

# Initialize trackers
trackers = {}
tracker_id = 0

# Trajectory data storage
trajectory_data = {}

# Start processing video
cap = cv2.VideoCapture('/path/to/your/video.mp4')

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # YOLOv8 detection
    results = yolo_model.predict(frame, imgsz=input_size)
    
    # Parse detection results
    detected_coords = []
    for result in results:
        boxes = result.boxes
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            detected_coords.append((cx, cy))
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
    
    # Assign detections to existing trackers using Hungarian algorithm
    tracker_ids = list(trackers.keys())
    if tracker_ids:
        cost_matrix = np.zeros((len(tracker_ids), len(detected_coords)))
        for i, t_id in enumerate(tracker_ids):
            pred = trackers[t_id].prediction
            for j, coord in enumerate(detected_coords):
                cost_matrix[i, j] = np.linalg.norm(pred - np.array([[coord[0]], [coord[1]]]))
        
        row_inds, col_inds = linear_sum_assignment(cost_matrix)
        assigned_detections = set()
        for r, c in zip(row_inds, col_inds):
            if cost_matrix[r, c] < 50:  # Threshold for assignment
                trackers[tracker_ids[r]].update(detected_coords[c])
                assigned_detections.add(c)
        
        # Handle unassigned detections
        for j, coord in enumerate(detected_coords):
            if j not in assigned_detections:
                trackers[tracker_id] = KalmanTracker()
                trackers[tracker_id].update(coord)
                tracker_id += 1
    else:
        # If no trackers exist, initialize them with the detections
        for coord in detected_coords:
            trackers[tracker_id] = KalmanTracker()
            trackers[tracker_id].update(coord)
            tracker_id += 1
    
    # Update trajectories
    for t_id in list(trackers.keys()):
        pred = trackers[t_id].prediction
        if t_id not in trajectory_data:
            trajectory_data[t_id] = deque(maxlen=9)
        
        trajectory_data[t_id].append([pred[0, 0], pred[1, 0]])
        
        # If trajectory is full, predict behavior
        if len(trajectory_data[t_id]) == trajectory_data[t_id].maxlen:
            data = np.array(trajectory_data[t_id]).reshape(1, trajectory_data[t_id].maxlen, 2)
            prediction = lstm_model.predict(data)
            behavior = np.argmax(prediction)
            cv2.putText(frame, f'ID: {t_id}, Behavior: {behavior}', (int(pred[0, 0]), int(pred[1, 0])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    
    cv2.imshow('Zebrafish Behavior Analysis', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

1. Kalman濾波器追蹤 (KalmanTracker 類別)

class KalmanTracker:
    def __init__(self):
        self.kalman = cv2.KalmanFilter(4, 2)
        self.kalman.measurementMatrix = np.array([[1, 0, 0, 0],
                                                  [0, 1, 0, 0]], np.float32)
        self.kalman.transitionMatrix = np.array([[1, 0, 1, 0],
                                                 [0, 1, 0, 1],
                                                 [0, 0, 1, 0],
                                                 [0, 0, 0, 1]], np.float32)
        self.kalman.processNoiseCov = np.array([[1, 0, 0, 0],
                                                [0, 1, 0, 0],
                                                [0, 0, 1, 0],
                                                [0, 0, 0, 1]], np.float32) * 0.03
        self.prediction = np.zeros((2, 1), np.float32)

    def update(self, coord):
        measurement = np.array([[np.float32(coord[0])],
                                [np.float32(coord[1])]])
        self.kalman.correct(measurement)
        self.prediction = self.kalman.predict()
        return self.prediction
  • 這部分定義了一個KalmanTracker類別,用於追蹤斑馬魚的位置。
  • KalmanFilter的輸入是4維狀態空間(位置和速度),以及2維測量空間(位置)。
  • measurementMatrix 是從測量到狀態的映射,transitionMatrix 描述了狀態的演變。
  • processNoiseCov 定義了過程噪聲的協方差矩陣,用來控制濾波器對不確定性的敏感度。
  • update()方法接收新的觀測位置(coord),用Kalman濾波器進行狀態更新,並返回預測的下一個位置。

2. YOLOv8模型加載

yolo_model = YOLO("yolov8n.pt")
  • 使用ultralytics庫來加載YOLOv8模型,這裡使用的是yolov8n.pt模型權重。該模型會用於偵測斑馬魚在每一幀中的位置。

3. LSTM模型加載與構建

lstm_model = load_model('/path/to/your/lstm_model.h5')

def build_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(64, return_sequences=True, input_shape=input_shape))
    model.add(Dropout(0.2))
    model.add(LSTM(64, return_sequences=False))
    model.add(Dropout(0.2))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(3, activation='softmax'))
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
  • load_model()函數用來加載已經訓練好的LSTM模型,用於預測斑馬魚的行為。
  • build_lstm_model()函數展示了如何從頭構建一個多層LSTM模型,包含兩層LSTM和兩個Dropout層,最後是兩個全連接層(Dense)。輸出層使用softmax來進行多分類輸出(假設有三種行為類別)。

4. 主要的處理流程

input_size = 640
trackers = {}
tracker_id = 0
trajectory_data = {}

cap = cv2.VideoCapture('/path/to/your/video.mp4')
  • 設定了YOLOv8模型的輸入尺寸為640
  • trackers字典用來保存所有斑馬魚的追蹤器(即KalmanTracker實例)。
  • trajectory_data字典用來存儲每條斑馬魚的軌跡數據。

5. 主循環:處理每一幀影像

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    results = yolo_model.predict(frame, imgsz=input_size)
    detected_coords = []
    for result in results:
        boxes = result.boxes
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy)
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            detected_coords.append((cx, cy))
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
  • 這段代碼處理每一幀影像,從影片中讀取幀,並將其輸入YOLOv8模型進行斑馬魚偵測。
  • 偵測結果包括斑馬魚的邊框位置(x1, y1, x2, y2),用於計算中心點座標(cx, cy),並將偵測到的中心點存入detected_coords列表。

6. 匈牙利算法匹配追蹤器

tracker_ids = list(trackers.keys())
if tracker_ids:
    cost_matrix = np.zeros((len(tracker_ids), len(detected_coords)))
    for i, t_id in enumerate(tracker_ids):
        pred = trackers[t_id].prediction
        for j, coord in enumerate(detected_coords):
            cost_matrix[i, j] = np.linalg.norm(pred - np.array([[coord[0]], [coord[1]]]))

    row_inds, col_inds = linear_sum_assignment(cost_matrix)
    assigned_detections = set()
    for r, c in zip(row_inds, col_inds):
        if cost_matrix[r, c] < 50:  # Threshold for assignment
            trackers[tracker_ids[r]].update(detected_coords[c])
            assigned_detections.add(c)

    for j, coord in enumerate(detected_coords):
        if j not in assigned_detections:
            trackers[tracker_id] = KalmanTracker()
            trackers[tracker_id].update(coord)
            tracker_id += 1
else:
    for coord in detected_coords:
        trackers[tracker_id] = KalmanTracker()
        trackers[tracker_id].update(coord)
        tracker_id += 1
  • 這段代碼使用匈牙利算法(線性分配問題)來將新的偵測結果與現有的追蹤器進行匹配。
  • cost_matrix表示追蹤器預測位置與新偵測位置之間的距離矩陣。
  • linear_sum_assignment函數用來找到總成本最小的匹配組合,結果是每個追蹤器與偵測位置的最佳配對。
  • 若距離小於門檻值(50像素),則更新追蹤器的位置;若有新的偵測位置未被分配,則為其創建一個新的Kalman追蹤器。

7. 更新軌跡數據並預測行為

for t_id in list(trackers.keys()):
    pred = trackers[t_id].prediction
    if t_id not in trajectory_data:
        trajectory_data[t_id] = deque(maxlen=9)
    
    trajectory_data[t_id].append([pred[0, 0], pred[1, 0]])

    if len(trajectory_data[t_id]) == trajectory_data[t_id].maxlen:
        data = np.array(trajectory_data[t_id]).reshape(1, trajectory_data[t_id].maxlen, 2)
        prediction = lstm_model.predict(data)
        behavior = np.argmax(prediction)
        cv2.putText(frame, f'ID: {t_id}, Behavior: {behavior}', (int(pred[0, 0]), int(pred[1, 0])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

cv2.imshow('Zebrafish Behavior Analysis', frame)
  • 每個追蹤器的預測位置存入對應的trajectory_data
  • 當某個追蹤器的軌跡數據達到9個點(滿足LSTM模型的輸入需求)時,將其轉換為適當的形狀並輸入LSTM模型進行行為預測。
  • 預測結果為行為類別(behavior),並將其顯示在影像上。

8.

結束與清理

cap.release()
cv2.destroyAllWindows()
  • 這段代碼在影片處理結束後釋放影片資源並關閉所有OpenCV視窗。

這個程式結合了物體偵測(YOLOv8)、追蹤(Kalman濾波器)、與行為預測(LSTM)來分析斑馬魚的行為。通過匈牙利算法進行追蹤器與偵測位置的匹配,以及使用多層LSTM來處理複雜的時間序列數據,這個程式能夠在動態場景中準確預測斑馬魚的行為。


上一篇
Day 27lstm 多隻斑馬魚所有行為分析並預測斑馬魚身體狀態
下一篇
day 29 基於人工智慧與深度學習對斑馬魚做行為分析
系列文
基於人工智慧與深度學習對斑馬魚做行為分析30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言