Skip to content

Commit 508671a

Browse files
committed
update
Signed-off-by: wep21 <daisuke.nishimatsu1021@gmail.com>
1 parent f84ef97 commit 508671a

File tree

7 files changed

+171
-29
lines changed

7 files changed

+171
-29
lines changed

MODULE.bazel

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ register_toolchains("@rust_toolchains//:all")
2929
crate = use_extension("@rules_rust//crate_universe:extensions.bzl", "crate")
3030
crate.spec(
3131
package = "image",
32-
version = "0.24",
32+
version = "0.25.0",
3333
)
3434
crate.spec(
3535
features = ["ndarray"],
@@ -64,6 +64,26 @@ crate.spec(
6464
package = "rand",
6565
version = "0.8.5",
6666
)
67+
crate.spec(
68+
package = "ab_glyph",
69+
version = "0.2.29",
70+
)
71+
crate.spec(
72+
package = "imageproc",
73+
version = "0.25.0",
74+
)
75+
crate.annotation(
76+
build_script_env = dict(
77+
CARGO_PKG_AUTHORS = "",
78+
CARGO_PKG_DESCRIPTION = "",
79+
CARGO_PKG_HOMEPAGE = "",
80+
CARGO_PKG_LICENSE = "",
81+
CARGO_PKG_REPOSITORY = "",
82+
RUSTDOC = "",
83+
),
84+
crate = "rav1e",
85+
repositories = ["crates"],
86+
)
6787
crate.from_specs()
6888
use_repo(crate, "crates")
6989

detector/src/rtmo.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
#[cxx::bridge(namespace = "rtmo")]
22
pub mod ffi {
3-
#[derive(Debug)]
3+
#[derive(Debug, Clone)]
44
struct Point {
55
x: f32,
66
y: f32,
77
}
88

9-
#[derive(Debug)]
9+
#[derive(Debug, Clone)]
1010
struct Bbox {
1111
tl: Point,
1212
br: Point,
1313
score: f32,
1414
class_index: i32,
1515
}
1616

17-
#[derive(Debug)]
17+
#[derive(Debug, Clone)]
1818
struct Keypoint {
1919
x: f32,
2020
y: f32,
2121
score: f32,
2222
}
2323

24-
#[derive(Debug)]
24+
#[derive(Debug, Clone)]
2525
struct PoseResult {
2626
keypoints: Vec<Keypoint>,
2727
bbox: Bbox,

examples/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ rust_binary(
1919
],
2020
deps = [
2121
"//detector:rtmo",
22+
"//tracker:tracker",
2223
"@crates//:image",
24+
"@crates//:imageproc",
2325
"@crates//:minifb",
26+
"@crates//:ab_glyph",
2427
"@crates//:video-rs",
2528
"@cxx.rs//:cxx",
2629
],

examples/src/video_demo.rs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
use ab_glyph::FontArc;
12
use cxx::let_cxx_string;
23
use cxx::CxxVector;
34
use image::{ImageBuffer, Rgb};
5+
use imageproc::drawing::draw_text_mut;
46
use minifb::{Key, Window, WindowOptions};
57
use rtmo::rtmo::ffi::{make_rtmo, PoseResult};
8+
use tracker::byte_tracker::{ByteTracker, PoseResultWithTrackID};
69
use video_rs::decode::Decoder;
710
use video_rs::location::Location;
811

@@ -23,6 +26,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
2326
(height as u32).try_into().unwrap(),
2427
);
2528
let mut detector = binding.pin_mut();
29+
let mut tracker = ByteTracker::new(12, 30, 0.5, 0.6, 0.8);
2630
let mut pose_results = CxxVector::<PoseResult>::new();
2731

2832
// Create a window for displaying frames
@@ -33,6 +37,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
3337
WindowOptions::default(),
3438
)?;
3539

40+
let font_data: &[u8] = include_bytes!("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf");
41+
let font = FontArc::try_from_slice(font_data).unwrap();
42+
let scale = 36.0f32;
43+
3644
for (_, frame) in decoder
3745
.decode_iter()
3846
.take_while(Result::is_ok)
@@ -49,20 +57,34 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
4957
let inference_time = start_time.elapsed();
5058
println!("Inference status: {}", status);
5159
println!("Inference time: {:?}", inference_time);
60+
let pose_results_with_track_id: Vec<PoseResultWithTrackID> = pose_results
61+
.iter()
62+
.map(|pose_result| PoseResultWithTrackID {
63+
track_id: None,
64+
pose: pose_result.clone(),
65+
})
66+
.collect();
67+
68+
let track_results = tracker
69+
.update(&pose_results_with_track_id)
70+
.unwrap_or_else(|e| {
71+
eprintln!("Error updating tracker: {}", e);
72+
vec![]
73+
});
5274

5375
// Draw inference results on the frame
5476
let mut annotated_image = ImageBuffer::from_fn(width as u32, height as u32, |x, y| {
5577
let pixel = rgb_image.get_pixel(x, y);
5678
*pixel
5779
});
5880

59-
annotated_image = pose_results
81+
annotated_image = track_results
6082
.iter()
61-
.fold(annotated_image, |mut img, pose_result| {
62-
let x_min = pose_result.bbox.tl.x as u32;
63-
let y_min = pose_result.bbox.tl.y as u32;
64-
let x_max = std::cmp::min(pose_result.bbox.br.x as u32, width - 1);
65-
let y_max = std::cmp::min(pose_result.bbox.br.y as u32, height - 1);
83+
.fold(annotated_image, |mut img, track_result| {
84+
let x_min = track_result.pose.bbox.tl.x as u32;
85+
let y_min = track_result.pose.bbox.tl.y as u32;
86+
let x_max = std::cmp::min(track_result.pose.bbox.br.x as u32, width - 1);
87+
let y_max = std::cmp::min(track_result.pose.bbox.br.y as u32, height - 1);
6688

6789
// Draw bounding box
6890
for x in x_min..=x_max {
@@ -75,7 +97,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
7597
}
7698

7799
// Draw keypoints
78-
for keypoint in &pose_result.keypoints {
100+
for keypoint in &track_result.pose.keypoints {
79101
let x = keypoint.x as u32;
80102
let y = keypoint.y as u32;
81103
if x < width && y < height {
@@ -102,6 +124,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
102124
}
103125
}
104126
}
127+
128+
// Draw track ID
129+
if let Some(track_id) = track_result.track_id {
130+
draw_text_mut(
131+
&mut img,
132+
Rgb([255, 255, 0]),
133+
x_min as i32,
134+
std::cmp::max(y_min as i32 - scale as i32, 0),
135+
scale,
136+
&font,
137+
&format!("{}", track_id),
138+
);
139+
}
105140
}
106141

107142
img

tracker/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ rust_library(
55
srcs = glob([
66
"src/*.rs",
77
]),
8+
visibility = ["//visibility:public"],
89
deps = [
10+
"//detector:rtmo",
911
"@crates//:nalgebra",
1012
"@crates//:num",
1113
"@crates//:thiserror",
@@ -19,6 +21,7 @@ rust_test(
1921
"src/*.rs",
2022
]),
2123
deps = [
24+
"//detector:rtmo",
2225
"@crates//:nalgebra",
2326
"@crates//:nearly_eq",
2427
"@crates//:num",

tracker/src/byte_tracker.rs

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::{
22
error::ByteTrackError,
33
lapjv::lapjv,
4-
object::Object,
54
rect::Rect,
65
strack::{STrack, STrackState},
76
};
@@ -10,16 +9,94 @@ use std::{collections::HashMap, vec};
109
* ByteTracker
1110
* ---------------------------------------------------------------------------- */
1211

12+
#[derive(Debug, Clone)]
13+
pub struct PoseResultWithTrackID {
14+
pub pose: rtmo::rtmo::ffi::PoseResult,
15+
pub track_id: Option<usize>,
16+
}
17+
18+
impl PoseResultWithTrackID {
19+
pub fn new(pose: rtmo::rtmo::ffi::PoseResult, track_id: Option<usize>) -> Self {
20+
Self { pose, track_id }
21+
}
22+
pub fn get_rect(&self) -> Rect<f32> {
23+
Rect::<f32>::new(
24+
self.pose.bbox.tl.x,
25+
self.pose.bbox.tl.y,
26+
self.pose.bbox.br.x - self.pose.bbox.tl.x,
27+
self.pose.bbox.br.y - self.pose.bbox.tl.y,
28+
)
29+
}
30+
pub fn get_prob(&self) -> f32 {
31+
self.pose.bbox.score as f32
32+
}
33+
pub fn get_pose(&self) -> Vec<rtmo::rtmo::ffi::Keypoint> {
34+
self.pose.keypoints.clone()
35+
}
36+
}
37+
38+
impl From<STrack> for PoseResultWithTrackID {
39+
fn from(strack: STrack) -> Self {
40+
let rect = strack.get_rect();
41+
let pose = strack.get_pose();
42+
let score = strack.get_score();
43+
let track_id = strack.get_track_id();
44+
PoseResultWithTrackID {
45+
pose: rtmo::rtmo::ffi::PoseResult {
46+
keypoints: pose,
47+
bbox: rtmo::rtmo::ffi::Bbox {
48+
tl: rtmo::rtmo::ffi::Point {
49+
x: rect.x(),
50+
y: rect.y(),
51+
},
52+
br: rtmo::rtmo::ffi::Point {
53+
x: rect.x() + rect.width(),
54+
y: rect.y() + rect.height(),
55+
},
56+
score: score,
57+
class_index: 0, // Assuming class_index is not used
58+
},
59+
},
60+
track_id: Some(track_id),
61+
}
62+
}
63+
}
64+
65+
impl From<&STrack> for PoseResultWithTrackID {
66+
fn from(strack: &STrack) -> Self {
67+
let rect = strack.get_rect();
68+
let pose = strack.get_pose();
69+
let score = strack.get_score();
70+
let track_id = strack.get_track_id();
71+
PoseResultWithTrackID {
72+
pose: rtmo::rtmo::ffi::PoseResult {
73+
keypoints: pose,
74+
bbox: rtmo::rtmo::ffi::Bbox {
75+
tl: rtmo::rtmo::ffi::Point {
76+
x: rect.x(),
77+
y: rect.y(),
78+
},
79+
br: rtmo::rtmo::ffi::Point {
80+
x: rect.x() + rect.width(),
81+
y: rect.y() + rect.height(),
82+
},
83+
score: score,
84+
class_index: 0, // Assuming class_index is not used
85+
},
86+
},
87+
track_id: Some(track_id),
88+
}
89+
}
90+
}
91+
1392
#[derive(Debug)]
1493
pub struct ByteTracker {
1594
track_thresh: f32,
1695
high_thresh: f32,
1796
match_thresh: f32,
1897
max_time_lost: usize,
19-
2098
frame_id: usize,
2199
track_id_count: usize,
22-
23100
tracked_stracks: Vec<STrack>,
24101
lost_stracks: Vec<STrack>,
25102
removed_stracks: Vec<STrack>,
@@ -38,17 +115,18 @@ impl ByteTracker {
38115
high_thresh,
39116
match_thresh,
40117
max_time_lost: (track_buffer as f32 * frame_rate as f32 / 30.0) as usize,
41-
42118
frame_id: 0,
43119
track_id_count: 0,
44-
45120
tracked_stracks: Vec::new(),
46121
lost_stracks: Vec::new(),
47122
removed_stracks: Vec::new(),
48123
}
49124
}
50125

51-
pub fn update(&mut self, objects: &Vec<Object>) -> Result<Vec<Object>, ByteTrackError> {
126+
pub fn update(
127+
&mut self,
128+
poses: &Vec<PoseResultWithTrackID>,
129+
) -> Result<Vec<PoseResultWithTrackID>, ByteTrackError> {
52130
self.frame_id += 1;
53131

54132
/* ------------------ Step 1: Get detections ------------------------- */
@@ -57,9 +135,9 @@ impl ByteTracker {
57135
let mut det_stracks = Vec::new();
58136
let mut det_low_stracks = Vec::new();
59137

60-
for obj in objects {
61-
let strack = STrack::new(obj.get_rect(), obj.get_prob());
62-
if obj.get_prob() >= self.track_thresh {
138+
for pose in poses {
139+
let strack = STrack::new(pose.get_rect(), pose.get_pose(), pose.get_prob());
140+
if pose.get_prob() >= self.track_thresh {
63141
det_stracks.push(strack);
64142
} else {
65143
det_low_stracks.push(strack);

0 commit comments

Comments
 (0)