2025 iThome Ironman Contest: "30 Days to Unlock the Rust Developer Toolbox" series, Day 10

[Day 08] Building an AI Object Detection Application with ort (Part 3)

TL;DR: The last three days have been hectic. I finally have time to properly consolidate and revise what I wrote over those three days, so today's post also carries corrections. Revisions are still in progress...

Architecture Design

Module Structure

src/
├── lib.rs              # Main public API
├── main.rs             # Demo application
├── model.rs            # Core YOLOv8 model logic
├── ort_backend.rs      # ONNX Runtime backend integration
├── yolo_result.rs      # Result data structures
├── config.rs           # Configuration management
├── nms.rs              # Non-maximum suppression algorithm
└── color.rs            # Visualization color utilities

Processing Pipeline

graph TD
    A[Input image] --> B[Preprocessing]
    B --> C[Model inference]
    C --> D[Postprocessing]
    D --> E[NMS filtering]
    E --> F[Output results]

  1. Preprocessing: resize the image, convert the color space, and normalize pixel values
  2. Model inference: run the neural network with ONNX Runtime
  3. Postprocessing: parse the model output and convert coordinates back to the original image
  4. NMS filtering: remove overlapping detection boxes (a sketch of nms.rs follows this list)
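
Since nms.rs itself isn't reproduced in this post, here is a minimal sketch of a greedy non-maximum suppression that matches the call site in model.rs (non_max_suppression(&mut data, self.iou)). It assumes only the Bbox accessors already used elsewhere in this article (confidence(), xmin(), ymin(), xmax(), ymax(), width(), height()):

use crate::{Bbox, Point2};

/// Greedy NMS: keep the highest-confidence box first, then drop any later
/// box that overlaps an already-kept box by more than `iou_threshold`.
pub fn non_max_suppression(
    xs: &mut Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)>,
    iou_threshold: f32,
) {
    // Sort candidates by confidence, highest first
    xs.sort_by(|a, b| b.0.confidence().partial_cmp(&a.0.confidence()).unwrap());

    let mut keep = 0;
    for i in 0..xs.len() {
        let mut suppressed = false;
        for j in 0..keep {
            if iou(&xs[j].0, &xs[i].0) > iou_threshold {
                suppressed = true;
                break;
            }
        }
        if !suppressed {
            // Move the kept candidate into the front region
            xs.swap(keep, i);
            keep += 1;
        }
    }
    xs.truncate(keep);
}

/// Intersection-over-union of two axis-aligned boxes.
fn iou(a: &Bbox, b: &Bbox) -> f32 {
    let ix = (a.xmax().min(b.xmax()) - a.xmin().max(b.xmin())).max(0.0);
    let iy = (a.ymax().min(b.ymax()) - a.ymin().max(b.ymin())).max(0.0);
    let inter = ix * iy;
    let union = a.width() * a.height() + b.width() * b.height() - inter;
    if union > 0.0 { inter / union } else { 0.0 }
}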

Implementation Example

main.rs

use std::fs;
use std::path::Path;

// `config`, `color`, and `YOLOv8` come from this project's library crate
// (lib.rs below); the crate name isn't shown in this post, so the
// corresponding `use` line is omitted here.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("ORT Tutorial - Simple YOLO Video Processing Demonstration");
    println!("=======================================================");

    // Direct configuration - no JSON file needed
    let config = config::Config {
        model: "model_weight/onnx/version/yolo11s.onnx".to_string(),
        cuda: false,
        height: 640,
        width: 640,
        device_id: 0,
        batch: 1,
        batch_min: 1,
        batch_max: 32,
        profile: false,
        window_view: true,
        trt: false,
        fp16: false,
        nc: None,
        nk: None,
        nm: None,
        conf: 0.3,
        iou: 0.45,
        kconf: 0.55,
        plot: false,
    };

    println!("Configuration set directly in code");
    println!("Model: {}", config.model);
    println!("Video dimensions: {}x{}", config.width, config.height);
    println!("CUDA enabled: {}", config.cuda);

    // Run the simple video processing demonstration
    process_video_simple(&config)
}

fn process_video_simple(config: &config::Config) -> Result<(), Box<dyn std::error::Error>> {
    println!("Starting simple video processing demonstration");

    // Load the YOLO model
    let mut model = YOLOv8::new(config.clone())?;

    // Hardcoded video path
    let video_path = "data/test_short.mp4";
    println!("Processing video: {}", video_path);

    // Open video capture
    let mut cap = opencv::videoio::VideoCapture::from_file(video_path, opencv::videoio::CAP_ANY)?;
    if !cap.is_opened()? {
        return Err(format!("Could not open video file: {}", video_path).into());
    }

    let frame_width = cap.get(opencv::videoio::CAP_PROP_FRAME_WIDTH)? as i32;
    let frame_height = cap.get(opencv::videoio::CAP_PROP_FRAME_HEIGHT)? as i32;
    let fps = cap.get(opencv::videoio::CAP_PROP_FPS)?;
    println!(
        "Video dimensions: {}x{}, FPS: {}",
        frame_width, frame_height, fps
    );

    // Ensure output directory exists
    let output_dir = Path::new("output");
    if !output_dir.exists() {
        let _ = fs::create_dir_all(output_dir);
    }

    // Setup video writer for output
    let output_path = output_dir.join("ort_tutorial_output.mp4");
    let fourcc = opencv::videoio::VideoWriter::fourcc('m', 'p', '4', 'v')?;
    let mut writer = opencv::videoio::VideoWriter::new(
        output_path.to_str().unwrap(),
        fourcc,
        fps,
        opencv::core::Size::new(frame_width, frame_height),
        true,
    )?;

    if !writer.is_opened()? {
        println!("Warning: Could not create video writer. Output will not be saved.");
    }

    let mut frame_count = 0;
    let mut total_detections = 0;
    let start_time = std::time::Instant::now();

    println!("Processing video frames...");

    let mut frame = opencv::core::Mat::default();
    while cap.read(&mut frame)? {
        if frame.empty() {
            break;
        }
        frame_count += 1;

        // Clone frame for processing
        let mut display_frame = frame.clone();

        // Run inference on the frame
        let predictions = model.run_mat(&vec![frame.clone()])?;

        // Process detections
        if let Some(prediction) = predictions.get(0) {
            if let Some(bboxes) = &prediction.bboxes {
                for bbox in bboxes {
                    total_detections += 1;

                    // Get color for the class
                    let color = color::get_class_color(bbox.id() as i32);

                    // Draw bounding box
                    let pt1 = opencv::core::Point::new(bbox.xmin() as i32, bbox.ymin() as i32);
                    let pt2 = opencv::core::Point::new(
                        (bbox.xmin() + bbox.width()) as i32,
                        (bbox.ymin() + bbox.height()) as i32,
                    );

                    opencv::imgproc::rectangle(
                        &mut display_frame,
                        opencv::core::Rect::from_points(pt1, pt2),
                        color,
                        2,
                        opencv::imgproc::LINE_8,
                        0,
                    )?;

                    // Draw label
                    let class_name = model
                        .names
                        .get(bbox.id() as usize)
                        .map_or("unknown", |s| s.as_str());

                    let label = format!("{} {:.2}", class_name, bbox.confidence());
                    let font_scale = 0.5;
                    let thickness = 1;
                    let font_face = opencv::imgproc::FONT_HERSHEY_SIMPLEX;

                    // Get text size
                    let text_size = opencv::imgproc::get_text_size(
                        &label, font_face, font_scale, thickness, &mut 0,
                    )?;

                    // Draw background rectangle for text
                    opencv::imgproc::rectangle(
                        &mut display_frame,
                        opencv::core::Rect::new(
                            pt1.x,
                            pt1.y - text_size.height - 5,
                            text_size.width + 10,
                            text_size.height + 10,
                        ),
                        color,
                        -1, // filled
                        opencv::imgproc::LINE_8,
                        0,
                    )?;

                    // Draw text
                    opencv::imgproc::put_text(
                        &mut display_frame,
                        &label,
                        opencv::core::Point::new(pt1.x + 5, pt1.y - 5),
                        font_face,
                        font_scale,
                        opencv::core::Scalar::new(255.0, 255.0, 255.0, 0.0),
                        thickness,
                        opencv::imgproc::LINE_8,
                        false,
                    )?;
                }
            }
        }

        // Write frame to output video
        if writer.is_opened()? {
            writer.write(&display_frame)?;
        }

        // Display frame if window_view is enabled
        if config.window_view {
            opencv::highgui::imshow("ORT Tutorial - YOLO Detection", &display_frame)?;
            let key = opencv::highgui::wait_key(1)?;
            if key == 27 {
                // ESC key
                break;
            }
        }

        // Print progress every 30 frames
        if frame_count % 30 == 0 {
            println!(
                "Processed {} frames, total detections: {}",
                frame_count, total_detections
            );
        }
    }

    let elapsed = start_time.elapsed();
    println!("Processing complete!");
    println!("Total frames processed: {}", frame_count);
    println!("Total detections: {}", total_detections);
    println!("Processing time: {:.2}s", elapsed.as_secs_f64());
    println!(
        "Average FPS: {:.2}",
        frame_count as f64 / elapsed.as_secs_f64()
    );

    if writer.is_opened()? {
        println!("Output video saved to: {}", output_path.display());
    }

    Ok(())
}
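
main.rs calls color::get_class_color, but color.rs isn't reproduced in this post. Here is a minimal sketch of what it might contain: a deterministic per-class BGR color, returned as the opencv::core::Scalar that the drawing calls above expect (the palette values are placeholders):

use opencv::core::Scalar;

/// Map a class id to a fixed, repeatable drawing color (BGR order).
pub fn get_class_color(class_id: i32) -> Scalar {
    const PALETTE: [(f64, f64, f64); 6] = [
        (56.0, 56.0, 255.0),
        (56.0, 255.0, 56.0),
        (255.0, 56.0, 56.0),
        (151.0, 157.0, 255.0),
        (31.0, 112.0, 255.0),
        (255.0, 112.0, 31.0),
    ];
    // Wrap class ids beyond the palette length
    let (b, g, r) = PALETTE[class_id.rem_euclid(PALETTE.len() as i32) as usize];
    Scalar::new(b, g, r, 0.0)
}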

lib.rs

#![allow(clippy::type_complexity)]

pub mod model;
pub mod ort_backend;
pub mod yolo_result;
pub mod config;
pub mod nms;
pub mod color;

pub use crate::model::YOLOv8;
pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask};
pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult};
pub use crate::nms::non_max_suppression;
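
For orientation, a minimal consumer of this public API might look like the following sketch (the crate name ort_tutorial and the image path are assumptions; substitute the name from your Cargo.toml):

use ort_tutorial::{config::Config, YOLOv8};

fn main() -> anyhow::Result<()> {
    // Same settings as the demo above, but without the preview window
    let config = Config {
        model: "model_weight/onnx/version/yolo11s.onnx".to_string(),
        cuda: false,
        height: 640,
        width: 640,
        device_id: 0,
        batch: 1,
        batch_min: 1,
        batch_max: 32,
        profile: false,
        window_view: false,
        trt: false,
        fp16: false,
        nc: None,
        nk: None,
        nm: None,
        conf: 0.3,
        iou: 0.45,
        kconf: 0.55,
        plot: false,
    };
    let mut model = YOLOv8::new(config)?;

    // Read one image with OpenCV and run a single inference pass
    let mat = opencv::imgcodecs::imread("data/test.jpg", opencv::imgcodecs::IMREAD_COLOR)?;
    let results = model.run_mat(&vec![mat])?;
    if let Some(bboxes) = results[0].bboxes.as_ref() {
        println!("detections: {}", bboxes.len());
    }
    Ok(())
}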

config.rs

#[derive(Debug, Clone)]
pub struct Config {
    pub model: String,
    pub cuda: bool,
    pub height: u32,
    pub width: u32,
    pub device_id: u32,
    pub batch: u32,
    pub batch_min: u32,
    pub batch_max: u32,
    pub profile: bool,
    pub window_view: bool,
    pub trt: bool,
    pub fp16: bool,
    pub nc: Option<u32>,
    pub nk: Option<u32>,
    pub nm: Option<u32>,
    pub conf: f32,
    pub iou: f32,
    pub kconf: f32,
    pub plot: bool,
}
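
One detail worth noting: nc, nk, and nm are Option<u32> overrides for the number of classes, keypoints, and mask coefficients. As model.rs shows below, the engine first tries to read these from the ONNX model metadata, then falls back to the config values, and panics if neither source provides them.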

model.rs

#![allow(clippy::type_complexity)]
// YOLOv8 model processing - the core inference logic.
// The lint is suppressed to keep the planned architecture intact for future development.

use anyhow::Result;
use image::{DynamicImage, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
use opencv::prelude::*;
use rand::{thread_rng, Rng};

use crate::{
    non_max_suppression, Batch, Bbox, Embedding, OrtBackend,
    OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask,
};


/// YOLOv8 model for all YOLO tasks
pub struct YOLOv8 {
    pub engine: OrtBackend,
    pub nc: u32,
    pub nk: u32,
    pub nm: u32,
    pub height: u32,
    pub width: u32,
    pub batch: u32,
    pub task: YOLOTask,
    pub conf: f32,
    pub kconf: f32,
    pub iou: f32,
    pub names: Vec<String>,
    pub color_palette: Vec<(u8, u8, u8)>,
    pub profile: bool,
}

impl YOLOv8 {
    pub fn new(config: crate::config::Config) -> Result<Self> {
        // execution provider
        let ep = if config.trt {
            OrtEP::Trt(config.device_id)
        } else if config.cuda {
            OrtEP::Cuda(config.device_id)
        } else {
            OrtEP::Cpu
        };

        // batch
        let batch = Batch {
            opt: config.batch,
            min: config.batch_min,
            max: config.batch_max,
        };

        // build ort engine
        let ort_args = OrtConfig {
            ep,
            batch,
            f: config.model,
            task: None, // Auto-detect from model
            trt_fp16: config.fp16,
            image_size: (Some(config.height), Some(config.width)),
        };
        let engine = OrtBackend::build(ort_args)?;

        //  get batch, height, width, tasks, nc, nk, nm
        let (batch, height, width, task) = (
            engine.batch(),
            engine.height(),
            engine.width(),
            engine.task(),
        );
        let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
            panic!("Failed to get num_classes, make it explicit with `--nc`");
        });
        let (nk, nm) = match task {
            YOLOTask::Pose => {
                let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
                    panic!("Failed to get num_keypoints, make it explicit with `--nk`");
                });
                (nk, 0)
            }
            YOLOTask::Segment => {
                let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
                    panic!("Failed to get num_masks, make it explicit with `--nm`");
                });
                (0, nm)
            }
            _ => (0, 0),
        };

        // class names
        let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);

        // color palette
        let mut rng = thread_rng();
        let color_palette: Vec<_> = names
            .iter()
            .map(|_| {
                (
                    rng.gen_range(0..=255),
                    rng.gen_range(0..=255),
                    rng.gen_range(0..=255),
                )
            })
            .collect();

        Ok(Self {
            engine,
            names,
            conf: config.conf,
            kconf: config.kconf,
            iou: config.iou,
            color_palette,
            profile: config.profile,
            nc,
            nk,
            nm,
            height,
            width,
            batch,
            task,
        })
    }

    pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
        let r = (w1 / w0).min(h1 / h0);
        (r, (w0 * r).round(), (h0 * r).round())
    }
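
    // Worked example: fitting a 1920x1080 frame into a 640x640 model input
    // gives r = min(640/1920, 640/1080) = 1/3, so scale_wh returns
    // (0.333..., 640.0, 360.0); run_mat below then letterboxes the remaining
    // 280 rows as 140 px of gray padding at the top and bottom.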



    pub fn run_mat(&mut self, xs: &Vec<opencv::core::Mat>) -> Result<Vec<YOLOResult>> {
        // pre-process
        let t_pre = std::time::Instant::now();

        // Convert cv::Mat to f32 array directly
        let mut xs_array =
            Array::ones((xs.len(), 3, self.height as usize, self.width as usize)).into_dyn();
        xs_array.fill(114.0 / 255.0); // Fill with gray padding color (114)

        let idx = 0; // NOTE: this demo only preprocesses the first Mat in the batch
        let mat = &xs[idx]; // Get that Mat from the input vector
        
        // Store original dimensions for scaling
        let original_width = mat.cols() as u32;
        let original_height = mat.rows() as u32;
        
        // Calculate scale to maintain aspect ratio (match preprocess method)
        let (_ratio, new_width, new_height) = self.scale_wh(
            original_width as f32,
            original_height as f32,
            self.width as f32,
            self.height as f32
        );
                            
        // Resize the image while maintaining aspect ratio
        let mut resized = opencv::core::Mat::default();
        opencv::imgproc::resize(
            mat,
            &mut resized,
            opencv::core::Size::new(new_width as i32, new_height as i32),
            0.0,
            0.0,
            opencv::imgproc::INTER_LINEAR,
        )?;

        // Convert BGR to RGB
        let bgr = resized.clone();
        let mut channels: opencv::core::Vector<opencv::core::Mat> = opencv::core::Vector::new();
        opencv::core::split(&bgr, &mut channels)?;
        
        // Swap B and R channels
        if channels.len() >= 3 {
            let temp = channels.get(0)?;
            channels.set(0, channels.get(2)?)?;
            channels.set(2, temp)?;
        }
        
        let mut rgb = opencv::core::Mat::default();
        opencv::core::merge(&channels, &mut rgb)?;

        // Calculate padding offsets to center the resized image
        let pad_left = ((self.width as f32 - new_width) / 2.0).floor() as i32;
        let pad_top = ((self.height as f32 - new_height) / 2.0).floor() as i32;
        
        // Copy resized image data into padded array
        for y in 0..new_height as i32 {
            for x in 0..new_width as i32 {
                if y < rgb.rows() && x < rgb.cols() {
                    let pixel = rgb.at_2d::<opencv::core::Vec3b>(y, x)?;
                    let target_y = (y + pad_top) as usize;
                    let target_x = (x + pad_left) as usize;
                    
                    if target_y < self.height as usize && target_x < self.width as usize {
                        xs_array[[idx, 0, target_y, target_x]] = pixel[0] as f32 / 255.0;
                        xs_array[[idx, 1, target_y, target_x]] = pixel[1] as f32 / 255.0;
                        xs_array[[idx, 2, target_y, target_x]] = pixel[2] as f32 / 255.0;
                    }
                }
            }
        }

        if self.profile {
            println!("[Model Preprocess]: {:?}", t_pre.elapsed());
        }

        // run
        let t_run = std::time::Instant::now();
        let ys = self.engine.run(xs_array, self.profile)?;
        if self.profile {
            println!("[Model Inference]: {:?}", t_run.elapsed());
        }

        // post-process
        let t_post = std::time::Instant::now();
        // Create dummy DynamicImage objects with original dimensions for postprocessing
        let dummy_images: Vec<DynamicImage> = xs
            .iter()
            .map(|mat| DynamicImage::new_rgb8(mat.cols() as u32, mat.rows() as u32))
            .collect();
        let ys = self.postprocess(ys, &dummy_images)?;
        if self.profile {
            println!("[Model Postprocess]: {:?}", t_post.elapsed());
        }

        Ok(ys)
    }


    pub fn postprocess(
        &self,
        xs: Vec<Array<f32, IxDyn>>,
        xs0: &[DynamicImage],
    ) -> Result<Vec<YOLOResult>> {
        if let YOLOTask::Classify = self.task {
            let mut ys = Vec::new();
            let preds = &xs[0];
            for batch in preds.axis_iter(Axis(0)) {
                ys.push(YOLOResult::new(
                    Some(Embedding::new(batch.into_owned())),
                    None,
                    None,
                    None,
                ));
            }
            Ok(ys)
        } else {
            const CXYWH_OFFSET: usize = 4; // cxcywh
            const KPT_STEP: usize = 3; // xyconf
            let preds = &xs[0];
            let protos = {
                if xs.len() > 1 {
                    Some(&xs[1])
                } else {
                    None
                }
            };
            let mut ys = Vec::new();
            for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
                // [bs, 4 + nc + nm, anchors]
                // input image
                let width_original = xs0[idx].width() as f32;
                let height_original = xs0[idx].height() as f32;
                
                // Calculate scaling ratio with improved precision
                let width_ratio = self.width as f32 / width_original;
                let height_ratio = self.height as f32 / height_original;
                let ratio = width_ratio.min(height_ratio);
                
                // Calculate actual dimensions used in the model (accounting for padding)
                let model_width = (width_original * ratio) as f32;
                let model_height = (height_original * ratio) as f32;
                
                // Calculate padding offsets (if any)
                let padx = (self.width as f32 - model_width) / 2.0;
                let pady = (self.height as f32 - model_height) / 2.0;
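                // Worked example (assuming a 1920x1080 source frame and a
                // 640x640 model input): ratio = 1/3, model size 640x360,
                // padx = 0.0, pady = 140.0. A model-space center (320, 320)
                // then maps back to ((320 - 0) / (1/3), (320 - 140) / (1/3))
                // = (960, 540) in original-image coordinates.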

                // save each result
                let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
                for pred in anchor.axis_iter(Axis(1)) {
                    // split preds for different tasks
                    let bbox = pred.slice(s![0..CXYWH_OFFSET]);
                    let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc as usize]);
                    let kpts = {
                        if let YOLOTask::Pose = self.task {
                            Some(pred.slice(s![pred.len() - KPT_STEP * self.nk as usize..]))
                        } else {
                            None
                        }
                    };
                    let coefs = {
                        if let YOLOTask::Segment = self.task {
                            Some(pred.slice(s![pred.len() - self.nm as usize..]).to_vec())
                        } else {
                            None
                        }
                    };

                    // confidence and id
                    let (id, &confidence) = clss
                        .into_iter()
                        .enumerate()
                        .reduce(|max, x| if x.1 > max.1 { x } else { max })
                        .unwrap(); // definitely will not panic!

                    // confidence filter
                    if confidence < self.conf {
                        continue;
                    }

                    // Adjust for padding if any
                    let cx = (bbox[0] - padx) / ratio;
                    let cy = (bbox[1] - pady) / ratio;
                    let w = bbox[2] / ratio;
                    let h = bbox[3] / ratio;
                    let x = cx - w / 2.;
                    let y = cy - h / 2.;
                    
                    // Ensure coordinates are within image bounds
                    let x_min = x.max(0.0).min(width_original);
                    let y_min = y.max(0.0).min(height_original);
                    let y_bbox = Bbox::new(
                        x_min,
                        y_min,
                        w.min(width_original - x_min),
                        h.min(height_original - y_min),
                        id,
                        confidence,
                    );

                    // kpts
                    let y_kpts = {
                        if let Some(kpts) = kpts {
                            let mut kpts_ = Vec::new();
                            // rescale keypoints with the same ratio
                            for i in 0..self.nk as usize {
                                let kx = (kpts[KPT_STEP * i] - padx) / ratio;
                                let ky = (kpts[KPT_STEP * i + 1] - pady) / ratio;
                                let kconf = kpts[KPT_STEP * i + 2];
                                if kconf < self.kconf {
                                    kpts_.push(Point2::default());
                                } else {
                                    kpts_.push(Point2::new_with_conf(
                                        kx.max(0.0f32).min(width_original),
                                        ky.max(0.0f32).min(height_original),
                                        kconf,
                                    ));
                                }
                            }
                            Some(kpts_)
                        } else {
                            None
                        }
                    };

                    // data merged
                    data.push((y_bbox, y_kpts, coefs));
                }

                // nms
                non_max_suppression(&mut data, self.iou);

                // decode
                let mut y_bboxes: Vec<Bbox> = Vec::new();
                let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
                let mut y_masks: Vec<Vec<u8>> = Vec::new();
                for elem in data.into_iter() {
                    if let Some(kpts) = elem.1 {
                        y_kpts.push(kpts)
                    }

                    // decode masks
                    if let Some(coefs) = elem.2 {
                        let proto = protos.unwrap().slice(s![idx, .., .., ..]);
                        let (nm, nh, nw) = proto.dim();

                        // coefs * proto -> mask
                        let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
                        let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
                        let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, 1)

                        // build image from ndarray
                        let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
                            match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
                                Some(image) => image,
                                None => panic!("can not create image from ndarray"),
                            };
                        let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn

                        // rescale masks
                        let (_, w_mask, h_mask) =
                            self.scale_wh(width_original, height_original, nw as f32, nh as f32);
                        let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
                        let mask_original = mask_cropped.resize_exact(
                            // resize_to_fill
                            width_original as u32,
                            height_original as u32,
                            match self.task {
                                YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
                                _ => image::imageops::FilterType::Triangle,
                            },
                        );

                        // crop-mask with bbox
                        let mut mask_original_cropped = mask_original.into_luma8();
                        for y in 0..height_original as usize {
                            for x in 0..width_original as usize {
                                if x < elem.0.xmin() as usize
                                    || x > elem.0.xmax() as usize
                                    || y < elem.0.ymin() as usize
                                    || y > elem.0.ymax() as usize
                                {
                                    mask_original_cropped.put_pixel(
                                        x as u32,
                                        y as u32,
                                        image::Luma([0u8]),
                                    );
                                }
                            }
                        }
                        y_masks.push(mask_original_cropped.into_raw());
                    }
                    y_bboxes.push(elem.0);
                }

                // save each result
                let y = YOLOResult {
                    probs: None,
                    bboxes: if !y_bboxes.is_empty() {
                        Some(y_bboxes)
                    } else {
                        None
                    },
                    keypoints: if !y_kpts.is_empty() {
                        Some(y_kpts)
                    } else {
                        None
                    },
                    masks: if !y_masks.is_empty() {
                        Some(y_masks)
                    } else {
                        None
                    },
                };
                ys.push(y);
            }

            Ok(ys)
        }
    }


}

