今天是第二十九天我們可以寫一個lstm結合yolo v8對於多隻斑馬魚行為分析的最終版本,我認為是我寫得最有效率的程式碼,以下是程式碼
import torch
import cv2
import numpy as np
from sort import Sort
from tensorflow.keras.models import load_model
from concurrent.futures import ThreadPoolExecutor, as_completed
import sqlite3
import logging
# 設置日誌
logging.basicConfig(filename='zebrafish_analysis.log', level=logging.INFO,
format='%(asctime)s:%(levelname)s:%(message)s')
# YOLOv8 檢測模組
class YOLOv8Detector:
def __init__(self):
self.model = torch.hub.load('ultralytics/yolov8', 'yolov8n')
def detect(self, frame):
results = self.model(frame)
return results.xyxy[0]
# 追蹤模組
class FishTracker:
def __init__(self):
self.tracker = Sort()
def track(self, detections):
tracked_objects = self.tracker.update(detections.cpu())
return tracked_objects
# 特徵提取模組
class FeatureExtractor:
def extract(self, tracks):
features = []
for track in tracks:
track_id, x_min, y_min, x_max, y_max = track
center_x = (x_min + x_max) / 2
center_y = (y_min + y_max) / 2
width = x_max - x_min
height = y_max - y_min
# 高階特徵提取,如速度變化率等
speed = np.sqrt((width ** 2 + height ** 2))
features.append([track_id, center_x, center_y, width, height, speed])
return np.array(features)
# 行為分類模組
class BehaviorClassifier:
def __init__(self, model_path):
self.lstm_model = load_model(model_path)
def classify(self, features):
predictions = self.lstm_model.predict(features)
return np.argmax(predictions, axis=1)
# 結果儲存模組
class ResultSaver:
def __init__(self, db_path='zebrafish_results.db'):
self.conn = sqlite3.connect(db_path)
self.cursor = self.conn.cursor()
self.cursor.execute('''CREATE TABLE IF NOT EXISTS results
(track_id INTEGER, behavior TEXT, timestamp TEXT)''')
self.conn.commit()
def save(self, track_id, behavior, timestamp):
self.cursor.execute("INSERT INTO results (track_id, behavior, timestamp) VALUES (?, ?, ?)",
(track_id, behavior, timestamp))
self.conn.commit()
def close(self):
self.conn.close()
# 主分析模組
class ZebrafishAnalyzer:
def __init__(self):
self.detector = YOLOv8Detector()
self.tracker = FishTracker()
self.extractor = FeatureExtractor()
self.classifier = BehaviorClassifier('/path_to_your_model/lstm_model.h5')
self.saver = ResultSaver()
def analyze_frame(self, frame, timestamp):
detections = self.detector.detect(frame)
tracks = self.tracker.track(detections)
features = self.extractor.extract(tracks)
if len(features) > 0:
behaviors = self.classifier.classify(features)
for i, track in enumerate(tracks):
track_id = track[0]
behavior = behaviors[i]
self.saver.save(track_id, behavior, timestamp)
logging.info(f'Track ID: {track_id}, Behavior: {behavior}, Timestamp: {timestamp}')
cv2.putText(frame, f'Behavior: {behavior}', (int(track[1]), int(track[2]) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
cv2.rectangle(frame, (int(track[1]), int(track[2])),
(int(track[3]), int(track[4])), (255, 0, 0), 2)
return frame
def close(self):
self.saver.close()
# 並行處理設計
def process_video(video_path):
analyzer = ZebrafishAnalyzer()
cap = cv2.VideoCapture(video_path)
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
futures.append(executor.submit(analyzer.analyze_frame, frame, timestamp))
if len(futures) > 10:
for future in as_completed(futures):
frame = future.result()
cv2.imshow('Zebrafish Behavior Analysis', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
cap.release()
cv2.destroyAllWindows()
analyzer.close()
return
for future in as_completed(futures):
frame = future.result()
cv2.imshow('Zebrafish Behavior Analysis', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
analyzer.close()
# 運行視頻處理
process_video('zebrafish_video.mp4')
我們將程式碼分為幾個模組來分別處理不同的任務:檢測(Detection)、追蹤(Tracking)、特徵提取(Feature Extraction)、行為分類(Behavior Classification)、結果儲存(Result Saving)以及主分析模組(Main Analysis)。這樣設計可以提高程式碼的可維護性,讓各部分可以單獨開發和測試。
YOLOv8Detector
)這個模組負責使用YOLOv8模型來檢測每一幀視頻中的斑馬魚位置。
class YOLOv8Detector:
def __init__(self):
self.model = torch.hub.load('ultralytics/yolov8', 'yolov8n')
def detect(self, frame):
results = self.model(frame)
return results.xyxy[0]
__init__
: 初始化時載入YOLOv8模型。detect
: 給定一幀視頻圖片,使用模型來檢測並返回所有斑馬魚的位置框。FishTracker
)這個模組使用Kalman Filter或DeepSORT來追蹤每一幀中已檢測到的斑馬魚。
class FishTracker:
def __init__(self):
self.tracker = Sort()
def track(self, detections):
tracked_objects = self.tracker.update(detections.cpu())
return tracked_objects
__init__
: 初始化時創建追蹤器對象(這裡使用的是Sort算法)。track
: 接受YOLOv8檢測到的位置,並更新追蹤器以返回每隻魚的追蹤結果。FeatureExtractor
)這個模組從追蹤的斑馬魚軌跡中提取高階特徵,如位置、速度、尺寸等。
class FeatureExtractor:
def extract(self, tracks):
features = []
for track in tracks:
track_id, x_min, y_min, x_max, y_max = track
center_x = (x_min + x_max) / 2
center_y = (y_min + y_max) / 2
width = x_max - x_min
height = y_max - y_min
# 高階特徵提取,如速度變化率等
speed = np.sqrt((width ** 2 + height ** 2))
features.append([track_id, center_x, center_y, width, height, speed])
return np.array(features)
extract
: 根據追蹤器的結果,提取斑馬魚的高階特徵,包括位置(中心點)、尺寸(寬度、高度)和速度(計算出的速度值)。BehaviorClassifier
)這個模組使用事先訓練好的LSTM模型來根據特徵分類斑馬魚的行為。
class BehaviorClassifier:
def __init__(self, model_path):
self.lstm_model = load_model(model_path)
def classify(self, features):
predictions = self.lstm_model.predict(features)
return np.argmax(predictions, axis=1)
__init__
: 載入LSTM模型,用於分類行為。classify
: 根據提取的特徵進行行為分類,並返回預測的行為標籤。ResultSaver
)這個模組將分析結果儲存在SQLite數據庫中,以便日後查詢和分析。
class ResultSaver:
def __init__(self, db_path='zebrafish_results.db'):
self.conn = sqlite3.connect(db_path)
self.cursor = self.conn.cursor()
self.cursor.execute('''CREATE TABLE IF NOT EXISTS results
(track_id INTEGER, behavior TEXT, timestamp TEXT)''')
self.conn.commit()
def save(self, track_id, behavior, timestamp):
self.cursor.execute("INSERT INTO results (track_id, behavior, timestamp) VALUES (?, ?, ?)",
(track_id, behavior, timestamp))
self.conn.commit()
def close(self):
self.conn.close()
__init__
: 連接到SQLite數據庫並創建儲存結果的表格。save
: 儲存每一隻斑馬魚的行為預測結果以及時間戳到數據庫。close
: 關閉數據庫連接。ZebrafishAnalyzer
)這個模組將所有其他模組整合起來,處理每一幀視頻,完成斑馬魚行為分析的全過程。
class ZebrafishAnalyzer:
def __init__(self):
self.detector = YOLOv8Detector()
self.tracker = FishTracker()
self.extractor = FeatureExtractor()
self.classifier = BehaviorClassifier('/path_to_your_model/lstm_model.h5')
self.saver = ResultSaver()
def analyze_frame(self, frame, timestamp):
detections = self.detector.detect(frame)
tracks = self.tracker.track(detections)
features = self.extractor.extract(tracks)
if len(features) > 0:
behaviors = self.classifier.classify(features)
for i, track in enumerate(tracks):
track_id = track[0]
behavior = behaviors[i]
self.saver.save(track_id, behavior, timestamp)
logging.info(f'Track ID: {track_id}, Behavior: {behavior}, Timestamp: {timestamp}')
cv2.putText(frame, f'Behavior: {behavior}', (int(track[1]), int(track[2]) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
cv2.rectangle(frame, (int(track[1]), int(track[2])),
(int(track[3]), int(track[4])), (255, 0, 0), 2)
return frame
def close(self):
self.saver.close()
__init__
: 初始化各個模組。analyze_frame
: 對單幀圖片進行全過程的分析,檢測、追蹤、提取特徵、分類行為並儲存結果,還會在影像上標註行為。close
: 在分析結束後關閉儲存模組。這部分使用了Python的ThreadPoolExecutor
來進行並行處理,使得多幀的處理可以同時進行,提高效率。
def process_video(video_path):
analyzer = ZebrafishAnalyzer()
cap = cv2.VideoCapture(video_path)
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
futures.append(executor.submit(analyzer.analyze_frame, frame, timestamp))
if len(futures) > 10:
for future in as_completed(futures):
frame = future.result()
cv2.imshow('Zebrafish Behavior Analysis', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
cap.release()
cv2.destroyAllWindows()
analyzer.close()
return
for future in as_completed(futures):
frame = future.result()
cv2.imshow('Zebrafish Behavior Analysis', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
analyzer.close()
process_video
: 處理整個視頻,使用多線程來並行處理多幀影像,當累積一定數量的處理結果後就展示在視窗中。在特徵提取部分,除了基本的空間特徵外,還加入了速度這樣的動態特徵,這些特徵有助於更準確的行為分類。
使用SQLite數據庫來儲存結果,這樣可以方便地查詢和進行後續的統計分析。
日誌記錄可以幫助跟蹤程式的運行狀態和結果,特別是在處理大批量數據時非常有用。
這整個程式是一個高度模組化、功能豐富的斑馬魚行為分析系統,適用於科研或高階應用場景。