今天是第二十五天我們可以寫一個lstm結合yolo去分析斑馬的模型系統,我們可以先看我們的模型準不準,以下是程式碼
首先,YOLOv8需要用於檢測斑馬魚的位置。你可以使用Ultralytics YOLOv8
框架進行模型的加載和檢測。
from ultralytics import YOLO
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import os
# YOLOv8 模型初始化
yolo_model = YOLO('yolov8n.pt') # 可以選擇不同的YOLOv8模型,根據需求選擇n, s, m, l, x等模型
# 檢測斑馬魚的位置
def detect_fish(frame):
results = yolo_model(frame) # 進行目標檢測
detections = []
for result in results.xyxy[0].cpu().numpy(): # 提取檢測結果
x_min, y_min, x_max, y_max, confidence, class_id = result
if class_id == 0: # 假設斑馬魚的class_id為0
x_center = (x_min + x_max) / 2
y_center = (y_min + y_max) / 2
detections.append((x_center, y_center))
return detections
YOLOv8模型檢測到斑馬魚的位置後,我們需要將其轉換為LSTM模型的輸入格式。這部分代碼將視頻幀或圖像序列轉換為LSTM所需的數據。
def prepare_data(frames, lookback=9):
X, y = [], []
for i in range(len(frames) - lookback):
input_seq = []
for j in range(lookback):
detected_fish = detect_fish(frames[i + j])
input_seq.append(detected_fish)
X.append(input_seq)
y.append(detect_fish(frames[i + lookback])) # 下一幀的真實位置
return np.array(X), np.array(y)
def load_video_frames(video_path):
cap = cv2.VideoCapture(video_path)
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
cap.release()
return frames
接下來,我們會構建一個更複雜的LSTM模型,並使用檢測到的數據進行訓練。模型包括多層LSTM、Dropout、BatchNormalization等層。
def create_complex_lstm_model(input_shape):
model = Sequential()
model.add(LSTM(256, input_shape=input_shape, return_sequences=True))
model.add(Dropout(0.3))
model.add(BatchNormalization())
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(0.3))
model.add(BatchNormalization())
model.add(LSTM(64))
model.add(Dropout(0.3))
model.add(Dense(32, activation='relu'))
model.add(Dense(2)) # 輸出斑馬魚的未來位置(x, y)
model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse')
return model
最後,將數據集拆分為訓練集和測試集,訓練LSTM模型並選擇最佳模型。
# 載入視頻幀
video_path = 'zebrafish_video.mp4'
frames = load_video_frames(video_path)
# 準備數據
lookback = 9
X, y = prepare_data(frames, lookback)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# 模型訓練
input_shape = (X_train.shape[1], X_train.shape[2], 2) # lookback, number of fish, (x, y)
model = create_complex_lstm_model(input_shape)
history = model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_val, y_val))
# 選擇最佳模型
val_loss = history.history['val_loss']
best_epoch = np.argmin(val_loss)
best_model = model
best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5')
print(f"最佳模型儲存在第{best_epoch+1}輪訓練後,驗證集損失為{val_loss[best_epoch]:.4f}")
from ultralytics import YOLO
import cv2
import numpy as np
這一部分導入了所需的庫:
ultralytics.YOLO
是 YOLOv8 的一個實現庫,它提供了簡便的接口來加載和使用 YOLOv8 模型。cv2
是 OpenCV 庫,用於處理圖像和視頻。cv2.VideoCapture
可以用來讀取視頻幀。numpy
是一個常用的數據處理庫,用於進行數學運算和處理多維數組。# YOLOv8 模型初始化
yolo_model = YOLO('yolov8n.pt')
模型初始化:
yolo_model = YOLO('yolov8n.pt')
加載了 YOLOv8 的一個預訓練模型,這裡使用的是 yolov8n.pt
,即 YOLOv8 的 nano 版本,這是一個相對輕量級的模型,適合資源有限的設備進行推理。# 檢測斑馬魚的位置
def detect_fish(frame):
results = yolo_model(frame)
detections = []
for result in results.xyxy[0].cpu().numpy():
x_min, y_min, x_max, y_max, confidence, class_id = result
if class_id == 0: # 假設斑馬魚的class_id為0
x_center = (x_min + x_max) / 2
y_center = (y_min + y_max) / 2
detections.append((x_center, y_center))
return detections
目標檢測與位置計算:
results = yolo_model(frame)
這行代碼對單幀圖像進行目標檢測,YOLOv8 返回檢測結果,結果包括邊界框座標、置信度和類別ID。for result in results.xyxy[0].cpu().numpy()
這行提取 YOLO 檢測的每一個目標,results.xyxy[0]
代表檢測出的邊界框數據。(x_min, y_min, x_max, y_max)
是檢測出的邊界框的左上角和右下角座標。class_id
表示該目標的類別標識符,if class_id == 0
假設斑馬魚的類別標識符為0。(x_center, y_center)
是斑馬魚的中心位置,用來描述斑馬魚的座標。detections
返回這些中心點的座標列表。def prepare_data(frames, lookback=9):
X, y = []
for i in range(len(frames) - lookback):
input_seq = []
for j in range(lookback):
detected_fish = detect_fish(frames[i + j])
input_seq.append(detected_fish)
X.append(input_seq)
y.append(detect_fish(frames[i + lookback])) # 下一幀的真實位置
return np.array(X), np.array(y)
準備序列數據:
prepare_data
函數將一系列幀圖像轉換為 LSTM 模型的訓練數據。lookback=9
指定 LSTM 模型回顧的時間步長,也就是模型在預測時需要參考前多少幀的數據。for i in range(len(frames) - lookback)
遍歷每一個可能的序列,lookback
决定了能構成多少個序列。input_seq.append(detected_fish)
用 YOLOv8 檢測每一幀的斑馬魚位置,並將其添加到 input_seq
中。y.append(detect_fish(frames[i + lookback]))
則存儲下一幀的真實位置作為目標輸出,用來訓練 LSTM。def load_video_frames(video_path):
cap = cv2.VideoCapture(video_path)
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
cap.release()
return frames
讀取視頻幀:
cap = cv2.VideoCapture(video_path)
使用 OpenCV 讀取指定路徑的視頻文件。while cap.isOpened()
確保視頻流被成功打開,cap.read()
讀取每一幀圖像。frames.append(frame)
將每一幀圖像存儲在列表 frames
中。from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
導入LSTM相關庫:
Sequential
是 Keras 中用來構建模型的順序容器。LSTM
是長短期記憶神經網絡層,用於處理序列數據。Dense
是全連接層,用於將LSTM層的輸出映射到目標輸出。Dropout
是一種正則化技術,用於防止過擬合。BatchNormalization
是一種技術,用於加快訓練速度並穩定學習過程。Adam
是一種常用的優化器,用於調整神經網絡的權重。train_test_split
用於將數據集拆分為訓練集和驗證集。def create_complex_lstm_model(input_shape):
model = Sequential()
model.add(LSTM(256, input_shape=input_shape, return_sequences=True))
model.add(Dropout(0.3))
model.add(BatchNormalization())
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(0.3))
model.add(BatchNormalization())
model.add(LSTM(64))
model.add(Dropout(0.3))
model.add(Dense(32, activation='relu'))
model.add(Dense(2)) # 輸出斑馬魚的未來位置(x, y)
model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse')
return model
構建LSTM模型:
Sequential
容器: 定義一個順序模型,其中的層按順序排列。LSTM(256, return_sequences=True)
: 第一層LSTM有256個單元,input_shape
指定了輸入的形狀,return_sequences=True
意味著該層輸出的每個時間步驟都將作為下一層的輸入。return_sequences=True
,因此這層的輸出將成為後續層的單一輸入。x
和 y
坐標。# 載入視頻幀
video_path = 'zebrafish_video.mp4'
frames = load_video_frames(video_path)
載入視頻幀:
load_video_frames
函數載入給定視頻的所有幀,並將它們存儲在 frames
列表中。
# 準備數據
lookback = 9
X, y = prepare_data(frames, lookback)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
數據準備:
lookback = 9
定義了 LSTM 模型的回顧步長。X, y = prepare_data(frames, lookback)
使用先前的 prepare_data
函數來準備數據,X
是輸入序列,y
是對應的目標輸出。train_test_split
將數據集拆分為訓練集和驗證集,驗證集佔20%。# 模型訓練
input_shape = (X_train.shape[1], X_train.shape[2], 2) # lookback, number of fish, (x, y)
model = create_complex_lstm_model(input_shape)
history = model.fit(X_train, y_train, epochs=200, batch_size=32, validation_data=(X_val, y_val))
模型訓練:
input_shape = (X_train.shape[1], X_train.shape[2], 2)
確定LSTM模型的輸入形狀,其中 X_train.shape[1]
是回顧步長,X_train.shape[2]
是斑馬魚的數量,每個斑馬魚有兩個值 x
和 y
。model.fit
函數用來訓練模型,epochs=200
表示模型會在整個數據集上訓練200次,batch_size=32
表示每次訓練32個樣本,validation_data
用於在每個 epoch 結束後計算驗證集的損失。# 選擇最佳模型
val_loss = history.history['val_loss']
best_epoch = np.argmin(val_loss)
best_model = model
best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5')
print(f"最佳模型儲存在第{best_epoch+1}輪訓練後,驗證集損失為{val_loss[best_epoch]:.4f}")
模型保存與選擇:
val_loss = history.history['val_loss']
獲取訓練過程中每個 epoch 的驗證損失值。best_epoch = np.argmin(val_loss)
找到驗證損失最小的 epoch,這個 epoch 對應最佳模型。best_model.save(f'best_lstm_yolo_model_epoch_{best_epoch}.h5')
將最佳模型保存為 .h5
文件,文件名包含最佳 epoch 的數字。print
函數輸出最佳模型所在的訓練輪次以及對應的驗證損失值。這段程式碼結合了YOLOv8的目標檢測能力和LSTM模型的時間序列預測能力,能夠對視頻中的斑馬魚進行行為分析。YOLOv8用於檢測斑馬魚的位置,而LSTM則學習這些位置的時間序列模式,從而預測未來的行為。這是一個相對複雜的深度學習應用,適用於動物行為研究中需要分析大量序列數據的場景。