1
+ import io
2
+ import os
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import Dict , IO , Optional , Union , List
6
+ import numpy as np
7
+ import pandas as pd
8
+ import yaml
9
+ import csv
10
+ from evo .core import sync
11
+ from evo .core .trajectory import PoseTrajectory3D
12
+ from evo .core .trajectory import Plane
13
+ from evo .core .metrics import PoseRelation , Unit
14
+ from evo .tools import file_interface
15
+ import evo .main_ape as main_ape
16
+ import evo .main_rpe as main_rpe
17
+ class FileInterfaceException (Exception ):
18
+ pass
19
+
20
+ PathStrHandle = Union [str , Path , IO ]
21
+
22
+ def read_tum_trajectory_matrix (
23
+ source : PathStrHandle ,
24
+ delim : str = " " ,
25
+ comment_str : str = "#" ,
26
+ ) -> np .ndarray :
27
+ """
28
+ Read a TUM‑style trajectory (.tum) into an N×8 float array:
29
+ [timestamp, x, y, z, qx, qy, qz, qw]
30
+ """
31
+ raw_mat = csv_read_matrix (source , delim = delim , comment_str = comment_str )
32
+ # every row must have exactly 8 fields
33
+ if any (len (row ) != 8 for row in raw_mat ):
34
+ raise FileInterfaceException (
35
+ "TUM trajectory files must have 8 entries per row (no trailing delimiters)."
36
+ )
37
+ error_msg = ("TUM trajectory files must have 8 entries per row "
38
+ "and no trailing delimiter at the end of the rows (space)" )
39
+
40
+ try :
41
+ mat = np .array (raw_mat ).astype (float )
42
+ except ValueError :
43
+ raise FileInterfaceException (error_msg )
44
+ stamps = mat [:, 0 ] # n x 1
45
+ xyz = mat [:, 1 :4 ] # n x 3
46
+ quat = mat [:, 4 :] # n x 4
47
+ quat = np .roll (quat , 1 , axis = 1 ) # shift 1 column -> w in front column
48
+ return PoseTrajectory3D (xyz , quat , stamps )
49
+
50
+
51
+ def has_utf8_bom (path : Union [str , Path ]) -> bool :
52
+ """Return True if file starts with UTF‑8 BOM (0xEF,0xBB,0xBF)."""
53
+ with open (path , "rb" ) as f :
54
+ return f .read (3 ) == b"\xef \xbb \xbf "
55
+
56
+ # -----------------------------------------------------------------------------
57
+ # CSV → 2D list of strings
58
+ # -----------------------------------------------------------------------------
59
+ def csv_read_matrix (
60
+ file_path : PathStrHandle ,
61
+ delim : str = "," ,
62
+ comment_str : str = "#" ,
63
+ ) -> List [List [str ]]:
64
+ """
65
+ Read a CSV‑like file (or handle) into a 2D list of raw strings,
66
+ skipping any lines beginning with `comment_str`.
67
+ """
68
+ # file‑like case
69
+ if isinstance (file_path , io .IOBase ):
70
+ gen = (line for line in file_path if not line .startswith (comment_str ))
71
+ return [row for row in csv .reader (gen , delimiter = delim )]
72
+
73
+ # path case
74
+ p = Path (file_path )
75
+ if not p .is_file ():
76
+ raise FileInterfaceException (f"File does not exist: { p } " )
77
+ skip_bom = has_utf8_bom (p )
78
+ with open (p , "r" , encoding = "utf-8" ) as f :
79
+ if skip_bom :
80
+ f .seek (3 )
81
+ gen = (line for line in f if not line .startswith (comment_str ))
82
+ return [row for row in csv .reader (gen , delimiter = delim )]
83
+
84
+
85
+ __all__ = ["TrajectoryEvaluator" ]
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Defaults – feel free to override via *config*
89
+ # ---------------------------------------------------------------------------
90
+ DEFAULTS : Dict [str , object ] = {
91
+ "t_max_diff" : 0.02 ,
92
+ "t_offset" : 0.0 ,
93
+ "n_to_align" : - 1 ,
94
+ "delta" : 1 ,
95
+ "unit" : "m" , # "m" | "frame"
96
+ "correct_scale" : False ,
97
+ "project_to_plane" : "xyz" , # "xyz" | "xy"
98
+ # Filtering
99
+ "enable_covariance_based_removal" : False ,
100
+ "covariance_percentile_threshold" : 95 ,
101
+ "enable_no_motion_removal" : False ,
102
+ "distance_threshold" : 10.0 ,
103
+ "ap20_peak_rejection" : False ,
104
+ "ap20_peak_rejection_threshold" : 0.1 ,
105
+ "ap20_peak_rejection_trailing_window" : 5 ,
106
+ }
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Main class
110
+ # ---------------------------------------------------------------------------
111
+
112
+
113
+ class TrajectoryEvaluator :
114
+ """Tiny wrapper around *evo* that hides all boilerplate.
115
+
116
+ Parameters
117
+ ----------
118
+ reference, estimated : str | IO | None, optional
119
+ Paths to ``.tum`` files **or** already opened file‑like objects. Can be
120
+ left *None* and supplied later to :py:meth:`evaluate`.
121
+ config : dict | str | None, optional
122
+ Dict or path to a YAML/JSON config file. Missing keys fall back to
123
+ sane :pydata:`DEFAULTS`.
124
+
125
+ Example
126
+ -------
127
+ >>> ev = TrajectoryEvaluator("ref.tum", "est.tum", {"delta": 5})
128
+ >>> metrics = ev.evaluate() # {"ATE": ..., "RTE": ..., "last_error": ...}
129
+ """
130
+
131
+ # .....................................................................
132
+ # Construction helpers
133
+ # .....................................................................
134
+
135
+ def __init__ (
136
+ self ,
137
+ reference : Optional [Union [str , IO ]] = None ,
138
+ estimated : Optional [Union [str , IO ]] = None ,
139
+ config : Optional [Union [str , Dict ]] = None ,
140
+ ) -> None :
141
+ self ._reference_source = reference
142
+ self ._estimated_source = estimated
143
+ self .config : Dict [str , object ] = DEFAULTS .copy ()
144
+ if config is not None :
145
+ self .update_config (config )
146
+
147
+ # ..................................................................
148
+ # Public helpers
149
+ # ..................................................................
150
+
151
+ def update_config (self , cfg : Union [str , Dict ]) -> None :
152
+ """Merge *cfg* into the current configuration."""
153
+ if isinstance (cfg , str ):
154
+ with open (cfg , "r" , encoding = "utf-8" ) as f :
155
+ cfg = yaml .safe_load (f )
156
+ if not isinstance (cfg , dict ):
157
+ raise TypeError ("config must be dict or path to YAML/JSON file" )
158
+ self .config .update (cfg )
159
+
160
+ # ------------------------------------------------------------------
161
+ # Main entry point
162
+ # ------------------------------------------------------------------
163
+
164
+ # ------------------------------------------------------------------
165
+ # Internal implementation
166
+ # ------------------------------------------------------------------
167
+
168
+ # -- compute --------------------------------------------------------
169
+
170
+ def evaluate (self , traj_ref : str , traj_est : str ) -> Dict [str , float ]:
171
+ cfg = self .config
172
+ delta_unit_enum = Unit .meters if cfg ["unit" ] == "m" else Unit .frames
173
+ if cfg ["unit" ] not in ("m" , "frame" ):
174
+ raise ValueError ("config['unit'] must be 'm' or 'frame'" )
175
+
176
+ plane_param = None if cfg ["project_to_plane" ] == "xyz" else Plane .XY
177
+
178
+ self ._apply_filters (traj_ref , traj_est )
179
+
180
+ traj_ref , traj_est = sync .associate_trajectories (
181
+ traj_ref , traj_est , cfg ["t_max_diff" ], cfg ["t_offset" ]
182
+ )
183
+
184
+ ape = main_ape .ape (
185
+ traj_ref ,
186
+ traj_est ,
187
+ est_name = "estimated" ,
188
+ ref_name = "RTS" ,
189
+ pose_relation = PoseRelation .point_distance ,
190
+ align = True ,
191
+ align_origin = False ,
192
+ n_to_align = cfg ["n_to_align" ],
193
+ correct_scale = cfg ["correct_scale" ],
194
+ project_to_plane = plane_param ,
195
+ )
196
+
197
+ rpe = main_rpe .rpe (
198
+ traj_ref ,
199
+ traj_est ,
200
+ est_name = "estimated" ,
201
+ ref_name = "RTS" ,
202
+ pose_relation = PoseRelation .point_distance ,
203
+ delta = cfg ["delta" ],
204
+ delta_unit = delta_unit_enum ,
205
+ all_pairs = False ,
206
+ align = True ,
207
+ correct_scale = cfg ["correct_scale" ],
208
+ n_to_align = cfg ["n_to_align" ],
209
+ project_to_plane = plane_param ,
210
+ support_loop = False ,
211
+ )
212
+
213
+ return {
214
+ "ATE" : float (ape .stats ["rmse" ]),
215
+ "RTE" : float (rpe .stats ["rmse" ]),
216
+ "LE" : float (ape .np_arrays ["error_array" ][- 1 ]),
217
+ }
218
+
219
+ # -- filters --------------------------------------------------------
220
+
221
+ def _apply_filters (self , traj_ref , traj_est : str ):
222
+ cfg = self .config
223
+ if cfg ["enable_no_motion_removal" ]:
224
+ self ._filter_no_motion (traj_est , cfg ["distance_threshold" ])
225
+ if cfg ["ap20_peak_rejection" ]:
226
+ self ._filter_ap20 (
227
+ traj_ref ,
228
+ cfg ["ap20_peak_rejection_threshold" ],
229
+ cfg ["ap20_peak_rejection_trailing_window" ],
230
+ )
231
+
232
+ @staticmethod
233
+ def _filter_no_motion (traj_est , threshold : float ):
234
+ d = np .linalg .norm (traj_est .positions_xyz - traj_est .positions_xyz [0 ], axis = 1 )
235
+ traj_est .reduce_to_ids (np .where (d > threshold )[0 ])
236
+
237
+ @staticmethod
238
+ def _filter_ap20 (traj_ref , thresh : float , trailing : int ):
239
+ t_diff = np .diff (traj_ref .timestamps )
240
+ idx_remove = np .where (t_diff > thresh )[0 ]
241
+ extended = set (idx_remove )
242
+ for i in idx_remove :
243
+ extended .update (range (max (0 , i - trailing ), i + trailing ))
244
+ traj_ref .reduce_to_ids (np .setdiff1d (np .arange (len (traj_ref .positions_xyz )), list (extended )))
0 commit comments