diff --git a/cflib/crazyflie/localization.py b/cflib/crazyflie/localization.py index 1d50d81d9..0d609789d 100644 --- a/cflib/crazyflie/localization.py +++ b/cflib/crazyflie/localization.py @@ -66,6 +66,7 @@ class Localization(): EXT_POSE_PACKED = 9 LH_ANGLE_STREAM = 10 LH_PERSIST_DATA = 11 + LH_MATCHED_ANGLE_STREAM = 12 def __init__(self, crazyflie=None): """ @@ -105,6 +106,8 @@ def _incoming(self, packet): decoded_data = bool(data[0]) elif pk_type == self.LH_ANGLE_STREAM: decoded_data = self._decode_lh_angle(data) + elif pk_type == self.LH_MATCHED_ANGLE_STREAM: + decoded_data = self._decode_matched_lh_angle(data) pk = LocalizationPacket(pk_type, data, decoded_data) self.receivedLocationPacket.call(pk) @@ -128,6 +131,27 @@ def _decode_lh_angle(self, data): return decoded_data + def _decode_matched_lh_angle(self, data): + decoded_data = {} + + raw_data = struct.unpack('> 4 + decoded_data['bs_count'] = raw_data[9] & 0x0F + + return decoded_data + def send_extpos(self, pos): """ Send the current Crazyflie X, Y, Z position. This is going to be diff --git a/cflib/crazyflie/platformservice.py b/cflib/crazyflie/platformservice.py index 262587a25..36421511c 100644 --- a/cflib/crazyflie/platformservice.py +++ b/cflib/crazyflie/platformservice.py @@ -42,6 +42,7 @@ PLATFORM_SET_CONT_WAVE = 0 PLATFORM_REQUEST_ARMING = 1 PLATFORM_REQUEST_CRASH_RECOVERY = 2 +PLATFORM_REQUEST_USER_NOTIFICATION = 3 VERSION_GET_PROTOCOL = 0 VERSION_GET_FIRMWARE = 1 @@ -110,6 +111,17 @@ def send_crash_recovery_request(self): pk.data = (PLATFORM_REQUEST_CRASH_RECOVERY, ) self._cf.send_packet(pk) + def send_user_notification(self, success: bool = True): + """ + Send a user notification to the Crazyflie. This is used to notify a user of some sort of event by using the + means available on the Crazyflie. + """ + pk = CRTPPacket() + pk.set_header(CRTPPort.PLATFORM, PLATFORM_COMMAND) + notification_type = 1 if success else 0 + pk.data = (PLATFORM_REQUEST_USER_NOTIFICATION, notification_type) + self._cf.send_packet(pk) + def get_protocol_version(self): """ Return version of the CRTP protocol diff --git a/cflib/localization/__init__.py b/cflib/localization/__init__.py index 6f4252c3e..0d873d0d7 100644 --- a/cflib/localization/__init__.py +++ b/cflib/localization/__init__.py @@ -23,8 +23,10 @@ from .lighthouse_bs_vector import LighthouseBsVector from .lighthouse_config_manager import LighthouseConfigFileManager from .lighthouse_config_manager import LighthouseConfigWriter +from .lighthouse_sweep_angle_reader import LighthouseMatchedSweepAngleReader from .lighthouse_sweep_angle_reader import LighthouseSweepAngleAverageReader from .lighthouse_sweep_angle_reader import LighthouseSweepAngleReader +from .lighthouse_utils import LighthouseCrossingBeam from .param_io import ParamFileManager __all__ = [ @@ -32,6 +34,8 @@ 'LighthouseBsVector', 'LighthouseSweepAngleAverageReader', 'LighthouseSweepAngleReader', + 'LighthouseMatchedSweepAngleReader', 'LighthouseConfigFileManager', 'LighthouseConfigWriter', - 'ParamFileManager'] + 'ParamFileManager', + 'LighthouseCrossingBeam'] diff --git a/cflib/localization/ippe_cf.py b/cflib/localization/ippe_cf.py index 51472dea0..c0bc306c1 100644 --- a/cflib/localization/ippe_cf.py +++ b/cflib/localization/ippe_cf.py @@ -65,7 +65,7 @@ def solve(U_cf: npt.ArrayLike, Q_cf: npt.ArrayLike) -> list[Solution]: First param: Y (positive to the left) Second param: Z (positive up) - :return: A list that contains 2 sets of pose solution from IPPE including rotation matrix + :return: A list that contains 2 sets of pose solutions from IPPE including rotation matrix translation matrix, and reprojection error. The first solution in the list has the smallest reprojection error. """ diff --git a/cflib/localization/lighthouse_bs_vector.py b/cflib/localization/lighthouse_bs_vector.py index 67e035964..41c9ed66a 100644 --- a/cflib/localization/lighthouse_bs_vector.py +++ b/cflib/localization/lighthouse_bs_vector.py @@ -25,6 +25,7 @@ import numpy as np import numpy.typing as npt +import yaml class LighthouseBsVector: @@ -137,6 +138,32 @@ def projection(self) -> npt.NDArray[np.float32]: def _q(self): return math.tan(self._lh_v1_vert_angle) / math.sqrt(1 + math.tan(self._lh_v1_horiz_angle) ** 2) + def __eq__(self, other): + if not isinstance(other, LighthouseBsVector): + return NotImplemented + + return (self._lh_v1_horiz_angle == other._lh_v1_horiz_angle and + self._lh_v1_vert_angle == other._lh_v1_vert_angle) + + @staticmethod + def yaml_representer(dumper, data: 'LighthouseBsVector'): + return dumper.represent_mapping('!LighthouseBsVector', { + 'lh_v1_angles': [data.lh_v1_horiz_angle, data.lh_v1_vert_angle], + }) + + @staticmethod + def yaml_constructor(loader, node): + values = loader.construct_mapping(node, deep=True) + lh_v1_angles = values.get('lh_v1_angles', [0.0, 0.0]) + if len(lh_v1_angles) != 2: + raise ValueError('lh_v1_angles must be a list of two angles') + lh_v1_horiz_angle, lh_v1_vert_angle = lh_v1_angles + return LighthouseBsVector(lh_v1_horiz_angle, lh_v1_vert_angle) + + +yaml.add_representer(LighthouseBsVector, LighthouseBsVector.yaml_representer) +yaml.add_constructor('!LighthouseBsVector', LighthouseBsVector.yaml_constructor) + class LighthouseBsVectors(list): """A list of 4 LighthouseBsVector, one for each sensor. @@ -144,7 +171,7 @@ class LighthouseBsVectors(list): def projection_pair_list(self) -> npt.NDArray: """ - Genereate a list of projection pairs for all vectors + Generate a list of projection pairs for all vectors """ result = np.empty((len(self), 2), dtype=float) for i, vector in enumerate(self): @@ -154,7 +181,7 @@ def projection_pair_list(self) -> npt.NDArray: def angle_list(self) -> npt.NDArray: """ - Genereate a list of angles for all vectors, the order is horizontal, vertical, horizontal, vertical... + Generate a list of angles for all vectors, the order is horizontal, vertical, horizontal, vertical... """ result = np.empty((len(self) * 2), dtype=float) for i, vector in enumerate(self): @@ -162,3 +189,18 @@ def angle_list(self) -> npt.NDArray: result[i * 2 + 1] = vector.lh_v1_vert_angle return result + + @staticmethod + def yaml_representer(dumper, data: 'LighthouseBsVectors'): + # Instead of using a sequence of LighthouseBsVector, we represent it as a sequence of lists to make it more + # compact + return dumper.represent_sequence('!LighthouseBsVectors', [list(vector.lh_v1_angle_pair) for vector in data]) + + @staticmethod + def yaml_constructor(loader, node): + values = loader.construct_sequence(node, deep=True) + return LighthouseBsVectors([LighthouseBsVector(pair[0], pair[1]) for pair in values]) + + +yaml.add_representer(LighthouseBsVectors, LighthouseBsVectors.yaml_representer) +yaml.add_constructor('!LighthouseBsVectors', LighthouseBsVectors.yaml_constructor) diff --git a/cflib/localization/lighthouse_cf_pose_sample.py b/cflib/localization/lighthouse_cf_pose_sample.py new file mode 100644 index 000000000..919dd3b72 --- /dev/null +++ b/cflib/localization/lighthouse_cf_pose_sample.py @@ -0,0 +1,109 @@ +from typing import NamedTuple + +import numpy as np +import numpy.typing as npt +import yaml + +from .ippe_cf import IppeCf +from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_types import Pose + +ArrayFloat = npt.NDArray[np.float_] + + +class BsPairPoses(NamedTuple): + """A type representing the poses of a pair of base stations""" + bs1: Pose + bs2: Pose + + +class LhCfPoseSample: + """ Represents a sample of a Crazyflie pose in space, it contains: + - a timestamp (if applicable) + - lighthouse angles from one or more base stations + - The the two solutions found by IPPE for each base station, in the cf ref frame. + + The ippe solution is somewhat heavy and is only created on demand by calling augment_with_ippe() + """ + + def __init__(self, angles_calibrated: dict[int, LighthouseBsVectors], timestamp: float = 0.0, + is_mandatory: bool = False) -> None: + self.timestamp: float = timestamp + + # Angles measured by the Crazyflie and compensated using calibration data + # Stored in a dictionary using base station id as the key + self.angles_calibrated: dict[int, LighthouseBsVectors] = angles_calibrated + + # A dictionary from base station id to BsPairPoses, The poses represents the two possible poses of the base + # stations found by IPPE, in the crazyflie reference frame. + self.ippe_solutions: dict[int, BsPairPoses] = {} + self.is_augmented = False + + # Some samples are mandatory and must not be removed, even if they appear to be outliers. For instance the + # the samples that define the origin or x-axis + self.is_mandatory = is_mandatory + + def augment_with_ippe(self, sensor_positions: ArrayFloat) -> None: + if not self.is_augmented: + self.ippe_solutions = self._find_ippe_solutions(self.angles_calibrated, sensor_positions) + self.is_augmented = True + + def is_empty(self) -> bool: + """Checks if no angles are set + + Returns: + bool: True if no angles are set + """ + return len(self.angles_calibrated) == 0 + + def _find_ippe_solutions(self, angles_calibrated: dict[int, LighthouseBsVectors], + sensor_positions: ArrayFloat) -> dict[int, BsPairPoses]: + + solutions: dict[int, BsPairPoses] = {} + for bs, angles in angles_calibrated.items(): + projections = angles.projection_pair_list() + estimates_ref_bs = IppeCf.solve(sensor_positions, projections) + estimates_ref_cf = self._convert_estimates_to_cf_reference_frame(estimates_ref_bs) + solutions[bs] = estimates_ref_cf + + return solutions + + def _convert_estimates_to_cf_reference_frame(self, estimates_ref_bs: list[IppeCf.Solution]) -> BsPairPoses: + """ + Convert the two ippe solutions from the base station reference frame to the CF reference frame + """ + rot_1 = estimates_ref_bs[0].R.transpose() + t_1 = np.dot(rot_1, -estimates_ref_bs[0].t) + + rot_2 = estimates_ref_bs[1].R.transpose() + t_2 = np.dot(rot_2, -estimates_ref_bs[1].t) + + return BsPairPoses(Pose(rot_1, t_1), Pose(rot_2, t_2)) + + def __eq__(self, other): + if not isinstance(other, LhCfPoseSample): + return NotImplemented + + return (self.timestamp == other.timestamp and + self.angles_calibrated == other.angles_calibrated and + self.is_mandatory == other.is_mandatory) + + @staticmethod + def yaml_representer(dumper, data: 'LhCfPoseSample'): + return dumper.represent_mapping('!LhCfPoseSample', { + 'timestamp': data.timestamp, + 'angles_calibrated': data.angles_calibrated, + 'is_mandatory': data.is_mandatory + }) + + @staticmethod + def yaml_constructor(loader, node): + values = loader.construct_mapping(node, deep=True) + timestamp = values.get('timestamp', 0.0) + angles_calibrated = values.get('angles_calibrated', {}) + is_mandatory = values.get('is_mandatory', False) + return LhCfPoseSample(angles_calibrated, timestamp, is_mandatory) + + +yaml.add_representer(LhCfPoseSample, LhCfPoseSample.yaml_representer) +yaml.add_constructor('!LhCfPoseSample', LhCfPoseSample.yaml_constructor) diff --git a/cflib/localization/lighthouse_geo_estimation_manager.py b/cflib/localization/lighthouse_geo_estimation_manager.py new file mode 100644 index 000000000..76eaa675e --- /dev/null +++ b/cflib/localization/lighthouse_geo_estimation_manager.py @@ -0,0 +1,553 @@ +# -*- coding: utf-8 -*- +# +# ,---------, ____ _ __ +# | ,-^-, | / __ )(_) /_______________ _____ ___ +# | ( O ) | / __ / / __/ ___/ ___/ __ `/_ / / _ \ +# | / ,--' | / /_/ / / /_/ /__/ / / /_/ / / /_/ __/ +# +------` /_____/_/\__/\___/_/ \__,_/ /___/\___/ +# +# Copyright (C) 2025 Bitcraze AB +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, in version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +import copy +import datetime +import os +import pathlib +import threading +from typing import TextIO + +import numpy as np +import numpy.typing as npt +import yaml + +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution +from cflib.localization.lighthouse_geometry_solver import LighthouseGeometrySolver +from cflib.localization.lighthouse_initial_estimator import LighthouseInitialEstimator +from cflib.localization.lighthouse_system_aligner import LighthouseSystemAligner +from cflib.localization.lighthouse_system_scaler import LighthouseSystemScaler +from cflib.localization.lighthouse_types import LhBsCfPoses +from cflib.localization.lighthouse_utils import LighthouseCrossingBeam + + +ArrayFloat = npt.NDArray[np.float_] + + +class LhGeoEstimationManager(): + REFERENCE_DIST = 1.0 # Reference distance used for scaling the solution + + @classmethod + def align_and_scale_solution(cls, container: LhGeoInputContainerData, poses: LhBsCfPoses, + reference_distance: float) -> LhBsCfPoses: + start_idx_x_axis = 1 + start_idx_xy_plane = start_idx_x_axis + len(container.x_axis) + start_idx_xyz_space = start_idx_xy_plane + len(container.xy_plane) + + origin_pos = poses.cf_poses[0].translation + x_axis_poses = poses.cf_poses[start_idx_x_axis:start_idx_x_axis + len(container.x_axis)] + x_axis_pos = list(map(lambda x: x.translation, x_axis_poses)) + xy_plane_poses = poses.cf_poses[start_idx_xy_plane:start_idx_xyz_space] + xy_plane_pos = list(map(lambda x: x.translation, xy_plane_poses)) + + # Align the solution + bs_aligned_poses, trnsfrm = LighthouseSystemAligner.align(origin_pos, x_axis_pos, xy_plane_pos, poses.bs_poses) + cf_aligned_poses = list(map(trnsfrm.rotate_translate_pose, poses.cf_poses)) + + # Scale the solution + bs_scaled_poses, cf_scaled_poses, scale = LighthouseSystemScaler.scale_fixed_point(bs_aligned_poses, + cf_aligned_poses, + [reference_distance, 0, 0], + cf_aligned_poses[1]) + + return LhBsCfPoses(bs_poses=bs_scaled_poses, cf_poses=cf_scaled_poses) + + @classmethod + def estimate_geometry(cls, container: LhGeoInputContainerData) -> LighthouseGeometrySolution: + """Estimate the geometry of the system based on samples recorded by a Crazyflie""" + solution = LighthouseGeometrySolution() + + matched_samples = container.get_matched_samples() + solution.progress_info = 'Data validation' + validated_matched_samples = cls._data_validation(matched_samples, container, solution) + if solution.progress_is_ok: + solution.progress_info = 'Initial estimation of geometry' + initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate(validated_matched_samples, + solution) + solution.poses = initial_guess + if solution.progress_is_ok: + solution.progress_info = 'Refining geometry solution' + LighthouseGeometrySolver.solve(initial_guess, cleaned_matched_samples, container.sensor_positions, + solution) + solution.progress_info = 'Align and scale solution' + scaled_solution = cls.align_and_scale_solution(container, solution.poses, cls.REFERENCE_DIST) + solution.poses = scaled_solution + + cls._create_solution_stats(validated_matched_samples, solution) + + cls._humanize_error_info(solution, container) + + # TODO krri indicate in the solution if there is a geometry. progress_is_ok is not a good indicator + + return solution + + @classmethod + def _data_validation(cls, matched_samples: list[LhCfPoseSample], container: LhGeoInputContainerData, + solution: LighthouseGeometrySolution) -> list[LhCfPoseSample]: + """Validate the data collected by the Crazyflie and update the solution object with the results""" + + result = [] + + NO_DATA = 'No data' + TOO_FEW_BS = 'Too few base stations recorded' + + # Check the origin sample + origin = container.origin + if len(origin.angles_calibrated) == 0: + solution.append_mandatory_issue_sample(origin, NO_DATA) + elif len(origin.angles_calibrated) == 1: + solution.append_mandatory_issue_sample(origin, TOO_FEW_BS) + + # Check the x-axis samples + if len(container.x_axis) == 0: + solution.is_x_axis_samples_valid = False + solution.x_axis_samples_info = NO_DATA + solution.progress_is_ok = False + + if len(container.xy_plane) == 0: + solution.is_xy_plane_samples_valid = False + solution.xy_plane_samples_info = NO_DATA + solution.progress_is_ok = False + + if len(container.xyz_space) == 0: + solution.xyz_space_samples_info = NO_DATA + + # Samples must contain at least two base stations + for sample in matched_samples: + if sample == container.origin: + result.append(sample) + continue # The origin sample is already checked + + if len(sample.angles_calibrated) >= 2: + result.append(sample) + else: + # If the sample is mandatory, we cannot remove it, but we can add an issue to the solution + if sample.is_mandatory: + solution.append_mandatory_issue_sample(sample, TOO_FEW_BS) + else: + # If the sample is not mandatory, we can ignore it + solution.xyz_space_samples_info = 'Sample(s) with too few base stations skipped' + continue + + return result + + @classmethod + def _humanize_error_info(cls, solution: LighthouseGeometrySolution, container: LhGeoInputContainerData) -> None: + """Humanize the error info in the solution object""" + + # There might already be an error reported earlier, so only check if we think the sample is valid + if solution.is_origin_sample_valid: + solution.is_origin_sample_valid, solution.origin_sample_info = cls._error_info_for(solution, + [container.origin]) + if solution.is_x_axis_samples_valid: + solution.is_x_axis_samples_valid, solution.x_axis_samples_info = cls._error_info_for(solution, + container.x_axis) + if solution.is_xy_plane_samples_valid: + solution.is_xy_plane_samples_valid, solution.xy_plane_samples_info = cls._error_info_for(solution, + container.xy_plane) + + @classmethod + def _error_info_for(cls, solution: LighthouseGeometrySolution, samples: list[LhCfPoseSample]) -> tuple[bool, str]: + """Check if any issue sample is registered and return a human readable error message""" + info_strings = [] + for sample in samples: + for issue_sample, issue in solution.mandatory_issue_samples: + if sample == issue_sample: + info_strings.append(issue) + + if len(info_strings) > 0: + return False, ', '.join(info_strings) + else: + return True, '' + + @classmethod + def _create_solution_stats(cls, matched_samples: list[LhCfPoseSample], solution: LighthouseGeometrySolution): + """Calculate statistics about the solution and store them in the solution object""" + + # Estimated worst error for each sample based on crossing beams + estimated_errors: list[float] = [] + + for sample in matched_samples: + bs_ids = list(sample.angles_calibrated.keys()) + + bs_angle_list = [(solution.poses.bs_poses[bs_id], sample.angles_calibrated[bs_id]) for bs_id in bs_ids] + sample_error = LighthouseCrossingBeam.max_distance_all_permutations(bs_angle_list) + estimated_errors.append(sample_error) + + solution.error_stats = LighthouseGeometrySolution.ErrorStats( + mean=np.mean(estimated_errors), + max=np.max(estimated_errors), + std=np.std(estimated_errors) + ) + + class SolverThread(threading.Thread): + """This class runs the geometry solver in a separate thread. + It is used to provide continuous updates of the solution as well as updating the geometry in the Crazyflie. + """ + + def __init__(self, container: LhGeoInputContainer, is_done_cb) -> None: + """This constructor initializes the solver thread and starts it. + It takes a container with the input data and an callback that is called when the solution is done. + The thread will run the geometry solver and return the solution in the callback as soon as the data in the + container is modified. + Args: + container (LhGeoInputContainer): A container with the input data for the geometry estimation. + is_done_cb: Callback function that is called when the solution is done. + """ + threading.Thread.__init__(self, name='LhGeoEstimationManager.SolverThread') + self.daemon = True + + self.container = container + self.latest_solved_data_version = -1 + + self.is_done_cb = is_done_cb + + self.is_running = False + self.is_done = False + self.time_to_stop = False + + def run(self): + """Run the geometry solver in a separate thread""" + self.is_running = True + + with self.container.is_modified_condition: + while True: + if self.time_to_stop: + break + + if self.container.get_data_version() > self.latest_solved_data_version: + self.is_done = False + + # Copy the container as the original container may be modified while the solver is running + container_copy = self.container.get_data_copy() + solution = LhGeoEstimationManager.estimate_geometry(container_copy) + self.latest_solved_data_version = container_copy.version + + self.is_done = True + self.is_done_cb(solution) + + self.container.is_modified_condition.wait(timeout=0.1) + + self.is_running = False + + def stop(self, do_join: bool = True): + """Stop the solver thread""" + self.time_to_stop = True + if do_join: + # Wait for the thread to finish + if self.is_running: + self.join() + + +class LhGeoInputContainerData(): + EMPTY_POSE_SAMPLE = LhCfPoseSample(angles_calibrated={}) + + def __init__(self, sensor_positions: ArrayFloat, version: int = 0) -> None: + self.sensor_positions = sensor_positions + + self.origin: LhCfPoseSample = self.EMPTY_POSE_SAMPLE + self.x_axis: list[LhCfPoseSample] = [] + self.xy_plane: list[LhCfPoseSample] = [] + self.xyz_space: list[LhCfPoseSample] = [] + + # Used by LhGeoInputContainer to track changes in the data + self.version = version + + def get_matched_samples(self) -> list[LhCfPoseSample]: + """Get all pose samples collected in a list + + Returns: + list[LhCfPoseSample]: _description_ + """ + return [self.origin] + self.x_axis + self.xy_plane + self.xyz_space + + def is_empty(self) -> bool: + """Check if the container is empty, meaning no samples are set + + Returns: + bool: True if the container is empty, False otherwise + """ + return (len(self.x_axis) == 0 and + len(self.xy_plane) == 0 and + len(self.xyz_space) == 0 and + self.origin == self.EMPTY_POSE_SAMPLE) + + @staticmethod + def yaml_representer(dumper, data: LhGeoInputContainerData): + return dumper.represent_mapping('!LhGeoInputContainerData', { + 'origin': data.origin, + 'x_axis': data.x_axis, + 'xy_plane': data.xy_plane, + 'xyz_space': data.xyz_space, + 'sensor_positions': data.sensor_positions.tolist(), + }) + + @staticmethod + def yaml_constructor(loader, node): + values = loader.construct_mapping(node, deep=True) + sensor_positions = np.array(values['sensor_positions'], dtype=np.float_) + result = LhGeoInputContainerData(sensor_positions) + + result.origin = values['origin'] + result.x_axis = values['x_axis'] + result.xy_plane = values['xy_plane'] + result.xyz_space = values['xyz_space'] + + # Augment the samples with the sensor positions + result.origin.augment_with_ippe(sensor_positions) + + for sample in result.x_axis: + sample.augment_with_ippe(sensor_positions) + + for sample in result.xy_plane: + sample.augment_with_ippe(sensor_positions) + + for sample in result.xyz_space: + sample.augment_with_ippe(sensor_positions) + + return result + + +yaml.add_representer(LhGeoInputContainerData, LhGeoInputContainerData.yaml_representer) +yaml.add_constructor('!LhGeoInputContainerData', LhGeoInputContainerData.yaml_constructor) + + +class LhGeoInputContainer(): + """This class holds the input data required by the geometry estimation functionality. + """ + FILE_TYPE_VERSION = 1 + + def __init__(self, sensor_positions: ArrayFloat) -> None: + self._data = LhGeoInputContainerData(sensor_positions) + self.is_modified_condition = threading.Condition() + + self._session_name = None + self._session_path = os.getcwd() + self._auto_save = False + + def set_origin_sample(self, origin: LhCfPoseSample) -> None: + """Store/update the sample to be used for the origin + + Args: + origin (LhCfPoseSample): the new origin + """ + with self.is_modified_condition: + self._data.origin = origin + self._augment_sample(self._data.origin, True) + self._handle_data_modification() + + def set_x_axis_sample(self, x_axis: LhCfPoseSample) -> None: + """Store/update the sample to be used for the x_axis + + Args: + x_axis (LhCfPoseSample): the new x-axis sample + """ + with self.is_modified_condition: + self._data.x_axis = [x_axis] + self._augment_samples(self._data.x_axis, True) + self._handle_data_modification() + + def set_xy_plane_samples(self, xy_plane: list[LhCfPoseSample]) -> None: + """Store/update the samples to be used for the xy-plane + + Args: + xy_plane (list[LhCfPoseSample]): the new xy-plane samples + """ + with self.is_modified_condition: + self._data.xy_plane = xy_plane + self._augment_samples(self._data.xy_plane, True) + self._handle_data_modification() + + def append_xy_plane_sample(self, xy_plane: LhCfPoseSample) -> None: + """append to the samples to be used for the xy-plane + + Args: + xy_plane (LhCfPoseSample): the new xy-plane sample + """ + with self.is_modified_condition: + self._augment_sample(xy_plane, True) + self._data.xy_plane.append(xy_plane) + self._handle_data_modification() + + def xy_plane_sample_count(self) -> int: + """Get the number of samples in the xy-plane + + Returns: + int: The number of samples in the xy-plane + """ + with self.is_modified_condition: + return len(self._data.xy_plane) + + def set_xyz_space_samples(self, samples: list[LhCfPoseSample]) -> None: + """Store/update the samples for the volume + + Args: + samples (list[LhMeasurement]): the new samples + """ + new_samples = samples + self._augment_samples(new_samples, False) + with self.is_modified_condition: + self._data.xyz_space = [] + self.append_xyz_space_samples(new_samples) + self._handle_data_modification() + + def append_xyz_space_samples(self, samples: list[LhCfPoseSample]) -> None: + """Append to the samples for the volume + + Args: + samples (LhMeasurement): the new samples + """ + new_samples = samples + self._augment_samples(new_samples, False) + with self.is_modified_condition: + self._data.xyz_space += new_samples + self._handle_data_modification() + + def xyz_space_sample_count(self) -> int: + """Get the number of samples in the xyz space + + Returns: + int: The number of samples in the xyz space + """ + with self.is_modified_condition: + return len(self._data.xyz_space) + + def clear_all_samples(self) -> None: + """Clear all samples in the container""" + self._set_new_data_container(LhGeoInputContainerData(self._data.sensor_positions)) + + def get_data_version(self) -> int: + """Get the current data version + + Returns: + int: The current data version + """ + with self.is_modified_condition: + return self._data.version + + def get_data_copy(self) -> LhGeoInputContainerData: + """Get a copy of the data in the container + + Returns: + LhGeoInputContainerData: A copy of the data in the container + """ + with self.is_modified_condition: + return copy.deepcopy(self._data) + + def is_empty(self) -> bool: + """Check if the container is empty + + Returns: + bool: True if the container is empty, False otherwise + """ + with self.is_modified_condition: + return self._data.is_empty() + + def save_as_yaml_file(self, text_io: TextIO): + """Save the data container as a YAML file + + Args: + text_io (TextIO): The text IO stream to write the YAML data to + """ + with self.is_modified_condition: + self.save_data_container_as_yaml(self._data, text_io) + + @classmethod + def save_data_container_as_yaml(cls, container_data: LhGeoInputContainerData, text_io: TextIO): + """Save the data container as a YAML string suitable for saving to a file + + Args: + container_data (LhGeoInputContainerData): The data container to save + text_io (TextIO): The text IO stream to write the YAML data to + """ + file_data = { + 'file_type_version': cls.FILE_TYPE_VERSION, + 'data': container_data + } + yaml.dump(file_data, text_io, default_flow_style=False) + + def populate_from_file_yaml(self, text_io: TextIO) -> None: + """Load the data from file + + Args: + text_io (TextIO): The text IO stream to read the YAML data from + Raises: + ValueError: If the file type version is not supported + """ + file_yaml = yaml.load(text_io, Loader=yaml.FullLoader) + if file_yaml['file_type_version'] != self.FILE_TYPE_VERSION: + raise ValueError(f'Unsupported file type version: {file_yaml["file_type_version"]}') + self._set_new_data_container(file_yaml['data']) + + def enable_auto_save(self, session_path: str = os.getcwd()) -> None: + """Enable auto-saving of the session data to a file in the specified path. + Session files will be named with the current date and time. + + Args: + session_path (str): The path to save the session data to. Defaults to the current working directory. + """ + self._session_path = session_path + self._auto_save = True + + def _set_new_data_container(self, new_data: LhGeoInputContainerData) -> None: + """Set a new data container and update the version""" + + # Maintain version + with self.is_modified_condition: + current_version = self._data.version + self._data = new_data + self._data.version = current_version + + self._new_session() + self._handle_data_modification() + + def _augment_sample(self, sample: LhCfPoseSample, is_mandatory: bool) -> None: + sample.augment_with_ippe(self._data.sensor_positions) + sample.is_mandatory = is_mandatory + + def _augment_samples(self, samples: list[LhCfPoseSample], is_mandatory: bool) -> None: + for sample in samples: + self._augment_sample(sample, is_mandatory) + + def _handle_data_modification(self) -> None: + """Update the data version and notify the waiting thread""" + with self.is_modified_condition: + self._data.version += 1 + self.is_modified_condition.notify() + + self._save_session() + + def _save_session(self) -> None: + if self._auto_save and not self.is_empty(): + if self._session_name is None: + self._session_name = datetime.datetime.now().isoformat(timespec='seconds') + + file_name = os.path.join(self._session_path, f'lh_geo_{self._session_name}.yaml') + pathlib.Path(self._session_path).mkdir(parents=True, exist_ok=True) + with open(file_name, 'w', encoding='UTF8') as handle: + self.save_as_yaml_file(handle) + + def _new_session(self) -> None: + """Start a new session""" + self._session_name = None diff --git a/cflib/localization/lighthouse_geometry_solution.py b/cflib/localization/lighthouse_geometry_solution.py new file mode 100644 index 000000000..bbc531365 --- /dev/null +++ b/cflib/localization/lighthouse_geometry_solution.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# || ____ _ __ +# +------+ / __ )(_) /_______________ _____ ___ +# | 0xBC | / __ / / __/ ___/ ___/ __ `/_ / / _ \ +# +------+ / /_/ / / /_/ /__/ / / /_/ / / /_/ __/ +# || || /_____/_/\__/\___/_/ \__,_/ /___/\___/ +# +# Copyright (C) 2025 Bitcraze AB +# +# Crazyflie Nano Quadcopter Client +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from collections import namedtuple + +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_types import LhBsCfPoses + + +class LighthouseGeometrySolution: + """ + A class to represent the solution of a lighthouse geometry problem. + """ + + ErrorStats = namedtuple('ErrorStats', ['mean', 'max', 'std']) + + def __init__(self): + # The estimated poses of the base stations and the CF samples + self.poses = LhBsCfPoses(bs_poses={}, cf_poses=[]) + + # Information about errors in the solution + self.error_stats = self.ErrorStats(0.0, 0.0, 0.0) + + # Indicates if the solution converged (True). + # If it did not converge, the solution is possibly not good enough to use + self.has_converged = False + + # Progress information stating how far in the solution process we got + self.progress_info = '' + + # Indicates that all previous steps in the solution process were successful and that the next step + # can be executed. This is used to determine if the solution process can continue. + self.progress_is_ok = True + + # Issue descriptions + self.is_origin_sample_valid = True + self.origin_sample_info = '' + self.is_x_axis_samples_valid = True + self.x_axis_samples_info = '' + self.is_xy_plane_samples_valid = True + self.xy_plane_samples_info = '' + # For the xyz space, there are not any stopping errors, this string may contain information for the user though + self.xyz_space_samples_info = '' + + # Samples that are mandatory for the solution but where problems were encountered. The tuples contain the sample + # and a description of the issue. This list is used to extract issue descriptions for the user interface. + self.mandatory_issue_samples: list[tuple[LhCfPoseSample, str]] = [] + + # General failure information if the problem is not related to a specific sample + self.general_failure_info = '' + + # The number of links between base stations. The data is organized as a dictionary with base station ids as + # keys, mapped to a dictionary of base station ids and the number of links to other base stations. + # For example: link_count[1][2] = 3 means that base station 1 has 3 links to base station 2. + self.link_count: dict[int, dict[int, int]] = {} + + def append_mandatory_issue_sample(self, sample: LhCfPoseSample, issue: str): + """ + Append a sample with an issue to the list of mandatory issue samples. + + :param sample: The CF pose sample that has an issue. + :param issue: A description of the issue with the sample. + """ + self.mandatory_issue_samples.append((sample, issue)) + self.progress_is_ok = False diff --git a/cflib/localization/lighthouse_geometry_solver.py b/cflib/localization/lighthouse_geometry_solver.py index 949d70b50..ffa0e3e6e 100644 --- a/cflib/localization/lighthouse_geometry_solver.py +++ b/cflib/localization/lighthouse_geometry_solver.py @@ -25,12 +25,13 @@ import numpy.typing as npt import scipy.optimize +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_cf_pose_sample import Pose +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution from cflib.localization.lighthouse_types import LhBsCfPoses -from cflib.localization.lighthouse_types import LhCfPoseSample -from cflib.localization.lighthouse_types import Pose -class LighthouseGeometrySolution: +class SolverData: """ Represents a solution from the geometry solver. @@ -45,7 +46,7 @@ def __init__(self) -> None: # Nr of base stations self.n_bss: int = None - # Nr of parametrs per base station + # Nr of parameters per base station self.n_params_per_bs = self.len_pose # Nr of sampled Crazyflie poses @@ -64,24 +65,6 @@ def __init__(self) -> None: self.bs_id_to_index: dict[int, int] = {} self.bs_index_to_id: dict[int, int] = {} - # The solution ###################### - - # The estimated poses of the base stations - self.bs_poses: dict[int, Pose] = {} - - # The estimated poses of the CF samples - self.cf_poses: list[Pose] = [] - - # Estimated error for each base station in each sample - self.estimated_errors: list[dict[int, float]] = [] - - # Information about errors in the solution - self.error_info = {} - - # Indicates if the solution converged (True). - # If it did not converge, the solution is probably not good enough to use - self.success = False - class LighthouseGeometrySolver: """ @@ -135,34 +118,35 @@ class LighthouseGeometrySolver: @classmethod def solve(cls, initial_guess: LhBsCfPoses, matched_samples: list[LhCfPoseSample], - sensor_positions: npt.ArrayLike) -> LighthouseGeometrySolution: + sensor_positions: npt.ArrayLike, solution: LighthouseGeometrySolution) -> None: """ Solve for the pose of base stations and CF samples. The pose of the CF in sample 0 defines the global reference frame. - Iteration is terminated acceptable solution is found. If no solution is found after a fixed number of iterations - the solver is terminated. The success member of the result will indicate if a solution was found or not. + Iteration is terminated when an acceptable solution is found. If no solution is found after a fixed number of + iterations the solver is terminated. The has_converged member of the result will indicate if a solution was + found or not. Note: the solution may still be good enough to use even if it did not converge. :param initial_guess: Initial guess for the base stations and CF sample poses :param matched_samples: List of matched samples. :param sensor_positions: Sensor positions (3D), in the CF reference frame - :return: an instance of LighthouseGeometrySolution + :param solution: an instance of LighthouseGeometrySolution that is filled with the result """ - solution = LighthouseGeometrySolution() + defs = SolverData() - solution.n_bss = len(initial_guess.bs_poses) - solution.n_cfs = len(matched_samples) - solution.n_cfs_in_params = len(matched_samples) - 1 - solution.n_sensors = len(sensor_positions) - solution.bs_id_to_index, solution.bs_index_to_id = cls._create_bs_map(initial_guess.bs_poses) + defs.n_bss = len(initial_guess.bs_poses) + defs.n_cfs = len(matched_samples) + defs.n_cfs_in_params = len(matched_samples) - 1 + defs.n_sensors = len(sensor_positions) + defs.bs_id_to_index, defs.bs_index_to_id = cls._create_bs_map(initial_guess.bs_poses) target_angles = cls._populate_target_angles(matched_samples) idx_agl_pr_to_bs, idx_agl_pr_to_cf, idx_agl_pr_to_sens_pos, jac_sparsity = cls._populate_indexes_and_jacobian( - matched_samples, solution) - params_bs, params_cfs = cls._populate_initial_guess(initial_guess, solution) + matched_samples, defs) + params_bs, params_cfs = cls._populate_initial_guess(initial_guess, defs) # Extra arguments passed on to calc_residual() - args = (solution, idx_agl_pr_to_bs, idx_agl_pr_to_cf, idx_agl_pr_to_sens_pos, target_angles, sensor_positions) + args = (defs, idx_agl_pr_to_bs, idx_agl_pr_to_cf, idx_agl_pr_to_sens_pos, target_angles, sensor_positions) # Vector to optimize. Composed of base station parameters followed by cf parameters x0 = np.hstack((params_bs.ravel(), params_cfs.ravel())) @@ -174,11 +158,10 @@ def solve(cls, initial_guess: LhBsCfPoses, matched_samples: list[LhCfPoseSample] x_scale='jac', ftol=1e-8, method='trf', - max_nfev=solution.max_nr_iter, + max_nfev=defs.max_nr_iter, args=args) - cls._condense_results(result, solution, matched_samples) - return solution + cls._condense_results(result, defs, matched_samples, solution) @classmethod def _populate_target_angles(cls, matched_samples: list[LhCfPoseSample]) -> npt.NDArray: @@ -193,7 +176,7 @@ def _populate_target_angles(cls, matched_samples: list[LhCfPoseSample]) -> npt.N return np.array(result) @classmethod - def _populate_indexes_and_jacobian(cls, matched_samples: list[LhCfPoseSample], defs: LighthouseGeometrySolution + def _populate_indexes_and_jacobian(cls, matched_samples: list[LhCfPoseSample], defs: SolverData ) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray]: """ To speed up calculations all operations in the iteration phase are done on np.arrays of equal length (ish), @@ -254,7 +237,7 @@ def _populate_indexes_and_jacobian(cls, matched_samples: list[LhCfPoseSample], d @classmethod def _populate_initial_guess(cls, initial_guess: LhBsCfPoses, - defs: LighthouseGeometrySolution) -> tuple[npt.NDArray, npt.NDArray]: + defs: SolverData) -> tuple[npt.NDArray, npt.NDArray]: """ Generate parameters for base stations and CFs, this is the initial guess we start to iterate from. """ @@ -270,7 +253,7 @@ def _populate_initial_guess(cls, initial_guess: LhBsCfPoses, return params_bs, params_cfs @classmethod - def _params_to_struct(cls, params, defs: LighthouseGeometrySolution): + def _params_to_struct(cls, params, defs: SolverData): """ Convert the params list to two arrays, one for base stations and one for CFs """ @@ -282,7 +265,7 @@ def _params_to_struct(cls, params, defs: LighthouseGeometrySolution): return params_bs_poses, params_cf_poses @classmethod - def _calc_residual(cls, params, defs: LighthouseGeometrySolution, index_angle_pair_to_bs, index_angle_pair_to_cf, + def _calc_residual(cls, params, defs: SolverData, index_angle_pair_to_bs, index_angle_pair_to_cf, index_angle_pair_to_sensor_base, target_angles, sensor_positions): """ Calculate the residual for a set of parameters. The residual is defined as the distance from a sensor to the @@ -317,13 +300,13 @@ def _calc_residual(cls, params, defs: LighthouseGeometrySolution, index_angle_pa @classmethod def _poses_to_angle_pairs(cls, bss, cf_poses, sensor_base_pos, index_angle_pair_to_bs, index_angle_pair_to_cf, - index_angle_pair_to_sensor_base, defs: LighthouseGeometrySolution): + index_angle_pair_to_sensor_base, defs: SolverData): pairs = cls._calc_angle_pairs(bss[index_angle_pair_to_bs], cf_poses[index_angle_pair_to_cf], sensor_base_pos[index_angle_pair_to_sensor_base], defs) return pairs @classmethod - def _calc_angle_pairs(cls, bs_p_a, cf_p_a, sens_pos_p_a, defs: LighthouseGeometrySolution): + def _calc_angle_pairs(cls, bs_p_a, cf_p_a, sens_pos_p_a, defs: SolverData): """ Calculate angle pairs based on base station poses, cf poses and sensor positions @@ -368,7 +351,7 @@ def _pose_to_params(cls, pose: Pose) -> npt.NDArray: return np.concatenate((pose.rot_vec, pose.translation)) @classmethod - def _params_to_pose(cls, params: npt.ArrayLike, defs: LighthouseGeometrySolution) -> Pose: + def _params_to_pose(cls, params: npt.ArrayLike, defs: SolverData) -> Pose: """ Convert from the array format used in the solver to Pose """ @@ -396,56 +379,19 @@ def _create_bs_map(cls, initial_guess_bs_poses: dict[int, Pose]) -> tuple[dict[i return bs_id_to_index, bs_index_to_id @classmethod - def _condense_results(cls, lsq_result, solution: LighthouseGeometrySolution, - matched_samples: list[LhCfPoseSample]) -> None: - bss, cf_poses = cls._params_to_struct(lsq_result.x, solution) + def _condense_results(cls, lsq_result, defs: SolverData, matched_samples: list[LhCfPoseSample], + solution: LighthouseGeometrySolution) -> None: + bss, cf_poses = cls._params_to_struct(lsq_result.x, defs) # Extract CF pose estimates # First pose (origin) is not in the parameter list - solution.cf_poses.append(Pose()) + solution.poses.cf_poses.append(Pose()) for i in range(len(matched_samples) - 1): - solution.cf_poses.append(cls._params_to_pose(cf_poses[i], solution)) + solution.poses.cf_poses.append(cls._params_to_pose(cf_poses[i], defs)) # Extract base station pose estimates for index, pose in enumerate(bss): - bs_id = solution.bs_index_to_id[index] - solution.bs_poses[bs_id] = cls._params_to_pose(pose, solution) + bs_id = defs.bs_index_to_id[index] + solution.poses.bs_poses[bs_id] = cls._params_to_pose(pose, defs) - solution.success = lsq_result.success - - # Extract the error for each CF pose - residuals = lsq_result.fun - i = 0 - for sample in matched_samples: - sample_errors = {} - for bs_id in sorted(sample.angles_calibrated.keys()): - sample_errors[bs_id] = np.linalg.norm(residuals[i:i + 2]) - i += solution.n_sensors * 2 - solution.estimated_errors.append(sample_errors) - - solution.error_info = cls._aggregate_error_info(solution.estimated_errors) - - @classmethod - def _aggregate_error_info(cls, estimated_errors: list[dict[int, float]]) -> dict[str, float]: - error_per_bs = {} - errors = [] - for sample_errors in estimated_errors: - for bs_id, error in sample_errors.items(): - if bs_id not in error_per_bs: - error_per_bs[bs_id] = [] - error_per_bs[bs_id].append(error) - errors.append(error) - - error_info = {} - error_info['mean_error'] = np.mean(errors) - error_info['max_error'] = np.max(errors) - error_info['std_error'] = np.std(errors) - - error_info['bs'] = {} - for bs_id, errors in error_per_bs.items(): - error_info['bs'][bs_id] = {} - error_info['bs'][bs_id]['mean_error'] = np.mean(errors) - error_info['bs'][bs_id]['max_error'] = np.max(errors) - error_info['bs'][bs_id]['std_error'] = np.std(errors) - - return error_info + solution.has_converged = lsq_result.success diff --git a/cflib/localization/lighthouse_initial_estimator.py b/cflib/localization/lighthouse_initial_estimator.py index 1853d415a..6be1a2fca 100644 --- a/cflib/localization/lighthouse_initial_estimator.py +++ b/cflib/localization/lighthouse_initial_estimator.py @@ -26,9 +26,10 @@ import numpy as np import numpy.typing as npt -from .ippe_cf import IppeCf +from cflib.localization.lighthouse_cf_pose_sample import BsPairPoses +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution from cflib.localization.lighthouse_types import LhBsCfPoses -from cflib.localization.lighthouse_types import LhCfPoseSample from cflib.localization.lighthouse_types import LhException from cflib.localization.lighthouse_types import Pose @@ -42,12 +43,6 @@ class BsPairIds(NamedTuple): bs2: int -class BsPairPoses(NamedTuple): - """A type representing the poses of a pair of base stations""" - bs1: Pose - bs2: Pose - - class LighthouseInitialEstimator: """ Make initial estimates of base station and CF poses using IPPE (analytical solution). @@ -58,41 +53,38 @@ class LighthouseInitialEstimator: OUTLIER_DETECTION_ERROR = 0.5 @classmethod - def estimate(cls, matched_samples: list[LhCfPoseSample], sensor_positions: ArrayFloat) -> tuple[ - LhBsCfPoses, list[LhCfPoseSample]]: + def estimate(cls, matched_samples: list[LhCfPoseSample], + solution: LighthouseGeometrySolution) -> tuple[LhBsCfPoses, list[LhCfPoseSample]]: """ Make a rough estimate of the poses of all base stations and CF poses found in the samples. The pose of the Crazyflie in the first sample is used as a reference and will define the global reference frame. - :param matched_samples: A list of samples with lighthouse angles. - :param sensor_positions: An array with the sensor positions on the lighthouse deck (3D, CF ref frame) + :param matched_samples: A list of samples with lighthouse angles. It is assumed that all samples have data for + two or more base stations. + :param solution: A LighthouseGeometrySolution object to store progress information and issues in :return: an estimate of base station and Crazyflie poses, as well as a cleaned version of matched_samples where outliers are removed. """ - bs_positions = cls._find_solutions(matched_samples, sensor_positions) + bs_positions = cls._find_bs_to_bs_poses(matched_samples) # bs_positions is a map from bs-id-pair to position, where the position is the position of the second # bs, as seen from the first bs (in the first bs ref frame). - bs_poses_ref_cfs, cleaned_matched_samples = cls._angles_to_poses( - matched_samples, sensor_positions, bs_positions) - - # Use the first CF pose as the global reference frame. The pose of the first base station (as estimated by ippe) - # is used as the "true" position (reference) - reference_bs_pose = None - for bs_pose_ref_cfs in bs_poses_ref_cfs: - if len(bs_pose_ref_cfs) > 0: - bs_id, reference_bs_pose = list(bs_pose_ref_cfs.items())[0] - break - - if reference_bs_pose is None: - raise LhException('Too little data, no reference') - bs_poses: dict[int, Pose] = {bs_id: reference_bs_pose} + bs_poses_ref_cfs, cleaned_matched_samples = cls._angles_to_poses(matched_samples, bs_positions, solution) + cls._build_link_stats(cleaned_matched_samples, solution) + if not solution.progress_is_ok: + return LhBsCfPoses(bs_poses={}, cf_poses=[]), cleaned_matched_samples - # Calculate the pose of the remaining base stations, based on the pose of the first CF - cls._estimate_remaining_bs_poses(bs_poses_ref_cfs, bs_poses) + # Calculate the pose of all base stations, based on the pose of one base station + try: + bs_poses = cls._estimate_bs_poses(bs_poses_ref_cfs) + except LhException as e: + # At this point we might have too few base stations or we have islands of non-linked base stations. + solution.progress_is_ok = False + solution.general_failure_info = str(e) + return LhBsCfPoses(bs_poses={}, cf_poses=[]), cleaned_matched_samples # Now that we have estimated the base station poses, estimate the poses of the CF in all the samples cf_poses = cls._estimate_cf_poses(bs_poses_ref_cfs, bs_poses) @@ -100,8 +92,30 @@ def estimate(cls, matched_samples: list[LhCfPoseSample], sensor_positions: Array return LhBsCfPoses(bs_poses, cf_poses), cleaned_matched_samples @classmethod - def _find_solutions(cls, matched_samples: list[LhCfPoseSample], sensor_positions: ArrayFloat - ) -> dict[BsPairIds, ArrayFloat]: + def _build_link_stats(cls, matched_samples: list[LhCfPoseSample], solution: LighthouseGeometrySolution) -> None: + """ + Build statistics about the number of links between base stations, based on the matched samples. + :param matched_samples: List of matched samples + :param solution: A LighthouseGeometry Solution object to store issues in + """ + + def increase_link_count(bs1: int, bs2: int): + """Increase the link count between two base stations""" + if bs1 not in solution.link_count: + solution.link_count[bs1] = {} + if bs2 not in solution.link_count[bs1]: + solution.link_count[bs1][bs2] = 0 + solution.link_count[bs1][bs2] += 1 + + for sample in matched_samples: + bs_in_sample = sample.angles_calibrated.keys() + for bs1 in bs_in_sample: + for bs2 in bs_in_sample: + if bs1 != bs2: + increase_link_count(bs1, bs2) + + @classmethod + def _find_bs_to_bs_poses(cls, matched_samples: list[LhCfPoseSample]) -> dict[BsPairIds, ArrayFloat]: """ Find the pose of all base stations, in the reference frame of other base stations. @@ -113,7 +127,6 @@ def _find_solutions(cls, matched_samples: list[LhCfPoseSample], sensor_positions out in space, while the correct one will end up more or less in the same spot for all samples. :param matched_samples: List of matched samples - :param sensor_positions: list of sensor positions on the lighthouse deck, CF reference frame :return: Base stations poses in the reference frame of the other base stations. The data is organized as a dictionary of tuples with base station id pairs, mapped to positions. For instance the entry with key (2, 1) contains the position of base station 1, in the base station 2 reference frame. @@ -121,14 +134,7 @@ def _find_solutions(cls, matched_samples: list[LhCfPoseSample], sensor_positions position_permutations: dict[BsPairIds, list[list[ArrayFloat]]] = {} for sample in matched_samples: - solutions: dict[int, BsPairPoses] = {} - for bs, angles in sample.angles_calibrated.items(): - projections = angles.projection_pair_list() - estimates_ref_bs = IppeCf.solve(sensor_positions, projections) - estimates_ref_cf = cls._convert_estimates_to_cf_reference_frame(estimates_ref_bs) - solutions[bs] = estimates_ref_cf - - cls._add_solution_permutations(solutions, position_permutations) + cls._add_solution_permutations(sample.ippe_solutions, position_permutations) return cls._find_most_likely_positions(position_permutations) @@ -168,32 +174,26 @@ def _add_solution_permutations(cls, solutions: dict[int, BsPairPoses], pose3.translation, pose4.translation]) @classmethod - def _angles_to_poses(cls, matched_samples: list[LhCfPoseSample], sensor_positions: ArrayFloat, - bs_positions: dict[BsPairIds, ArrayFloat]) -> tuple[list[dict[int, Pose]], - list[LhCfPoseSample]]: + def _angles_to_poses(cls, matched_samples: list[LhCfPoseSample], bs_positions: dict[BsPairIds, ArrayFloat], + solution: LighthouseGeometrySolution) -> tuple[list[dict[int, Pose]], list[LhCfPoseSample]]: """ Estimate the base station poses in the Crazyflie reference frames, for each sample. - Use Ippe again to find the possible poses of the bases stations and pick the one that best matches the position - in bs_positions. + Again use the IPPE solutions to find the possible poses of the base stations and pick the one that best matches + the position in bs_positions. :param matched_samples: List of samples - :param sensor_positions: Positions of the sensors on the lighthouse deck (CF ref frame) :param bs_positions: Dictionary of base station positions (other base station ref frame) + :param solution: A LighthouseGeometrySolution object to store issues in :return: A list of dictionaries from base station to Pose of all base stations, for each sample, as well as - a version of the matched_samples where outliers are removed + a version of the matched_samples where outliers are removed. """ result: list[dict[int, Pose]] = [] cleaned_matched_samples: list[LhCfPoseSample] = [] for sample in matched_samples: - solutions: dict[int, BsPairPoses] = {} - for bs, angles in sample.angles_calibrated.items(): - projections = angles.projection_pair_list() - estimates_ref_bs = IppeCf.solve(sensor_positions, projections) - estimates_ref_cf = cls._convert_estimates_to_cf_reference_frame(estimates_ref_bs) - solutions[bs] = estimates_ref_cf + solutions = sample.ippe_solutions poses: dict[int, Pose] = {} ids = sorted(solutions.keys()) @@ -210,9 +210,13 @@ def _angles_to_poses(cls, matched_samples: list[LhCfPoseSample], sensor_position poses[pair_ids.bs2] = pair_poses.bs2 else: is_sample_valid = False + if sample.is_mandatory: + solution.append_mandatory_issue_sample(sample, 'Outlier detected') + else: + solution.xyz_space_samples_info = 'Sample(s) with outliers skipped' break - if is_sample_valid: + if is_sample_valid or sample.is_mandatory: result.append(poses) cleaned_matched_samples.append(sample) @@ -297,20 +301,7 @@ def _find_best_position_bucket(cls, buckets: list[list[ArrayFloat]]) -> ArrayFlo return pos @classmethod - def _convert_estimates_to_cf_reference_frame(cls, estimates_ref_bs: list[IppeCf.Solution]) -> BsPairPoses: - """ - Convert the two ippe solutions from the base station reference frame to the CF reference frame - """ - rot_1 = estimates_ref_bs[0].R.transpose() - t_1 = np.dot(rot_1, -estimates_ref_bs[0].t) - - rot_2 = estimates_ref_bs[1].R.transpose() - t_2 = np.dot(rot_2, -estimates_ref_bs[1].t) - - return BsPairPoses(Pose(rot_1, t_1), Pose(rot_2, t_2)) - - @classmethod - def _estimate_remaining_bs_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]], bs_poses: dict[int, Pose]) -> None: + def _estimate_bs_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]]) -> dict[int, Pose]: """ Based on one base station pose, estimate the other base station poses. @@ -318,6 +309,18 @@ def _estimate_remaining_bs_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]], b have information of base station pairs (0, 2) and (2, 3), from this we can first derive the pose of 2 and after that the pose of 3. """ + # Use the first CF pose as the global reference frame. The pose of the first base station (as estimated by ippe) + # is used as the pose that all other base stations are mapped to. + reference_bs_pose = None + for bs_pose_ref_cfs in bs_poses_ref_cfs: + if len(bs_pose_ref_cfs) > 0: + bs_id, reference_bs_pose = list(bs_pose_ref_cfs.items())[0] + break + + if reference_bs_pose is None: + raise LhException('Too little data, no reference') + bs_poses: dict[int, Pose] = {bs_id: reference_bs_pose} + # Find all base stations in the list all_bs = set() for initial_est_bs_poses in bs_poses_ref_cfs: @@ -327,9 +330,11 @@ def _estimate_remaining_bs_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]], b to_find = all_bs - bs_poses.keys() # run through the list of samples until we manage to find them all - remaining = len(to_find) - while remaining > 0: - buckets: dict[int, list[Pose]] = {} + # The process is like peeling an onion, from the inside out. In each iteration we find the poses of + # the base stations that are closest to the ones we already have, until we have found all poses. + remaining_to_find = len(to_find) + while remaining_to_find > 0: + averaging_storage: dict[int, list[Pose]] = {} for bs_poses_in_sample in bs_poses_ref_cfs: unknown = to_find.intersection(bs_poses_in_sample.keys()) known = set(bs_poses.keys()).intersection(bs_poses_in_sample.keys()) @@ -348,25 +353,30 @@ def _estimate_remaining_bs_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]], b unknown_cf = bs_poses_in_sample[bs_id] # Finally we can calculate the BS pose in the global reference frame bs_pose = cls._map_pose_to_ref_frame(known_global, known_cf, unknown_cf) - if bs_id not in buckets: - buckets[bs_id] = [] - buckets[bs_id].append(bs_pose) + if bs_id not in averaging_storage: + averaging_storage[bs_id] = [] + averaging_storage[bs_id].append(bs_pose) # Average over poses and add to bs_poses - for bs_id, poses in buckets.items(): - bs_poses[bs_id] = cls._avarage_poses(poses) + for bs_id, poses in averaging_storage.items(): + bs_poses[bs_id] = cls._average_poses(poses) + # Remove the newly found base stations from the set of base stations to find to_find = all_bs - bs_poses.keys() if len(to_find) == 0: break - if len(to_find) == remaining: + if len(to_find) == remaining_to_find: + # We could not map any more poses, but some still remain to be found. This means that there are not + # links to all base stations. raise LhException('Can not link positions between all base stations') - remaining = len(to_find) + remaining_to_find = len(to_find) + + return bs_poses @classmethod - def _avarage_poses(cls, poses: list[Pose]) -> Pose: + def _average_poses(cls, poses: list[Pose]) -> Pose: """ Averaging of quaternions to get the "average" orientation of multiple samples. From https://stackoverflow.com/a/61013769 @@ -400,7 +410,7 @@ def _estimate_cf_poses(cls, bs_poses_ref_cfs: list[dict[int, Pose]], bs_poses: d est_ref_global = cls._map_cf_pos_to_cf_pos(pose_global, pose_cf) poses.append(est_ref_global) - cf_poses.append(cls._avarage_poses(poses)) + cf_poses.append(cls._average_poses(poses)) return cf_poses diff --git a/cflib/localization/lighthouse_sample_matcher.py b/cflib/localization/lighthouse_sample_matcher.py index afe26fb8f..ccc9c94d4 100644 --- a/cflib/localization/lighthouse_sample_matcher.py +++ b/cflib/localization/lighthouse_sample_matcher.py @@ -21,7 +21,8 @@ # along with this program. If not, see . from __future__ import annotations -from cflib.localization.lighthouse_types import LhCfPoseSample +from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample from cflib.localization.lighthouse_types import LhMeasurement @@ -33,32 +34,75 @@ class LighthouseSampleMatcher: a list of LhCfPoseSample. Matching is done using the timestamp and a maximum time span. """ + def __init__(self, max_time_diff: float = 0.020, min_nr_of_bs_in_match: int = 1) -> None: + self.max_time_diff = max_time_diff + self.min_nr_of_bs_in_match = min_nr_of_bs_in_match + + self._current_angles: dict[int, LighthouseBsVectors] = {} + self._current_ts = 0.0 + + def match_one(self, sample: LhMeasurement) -> LhCfPoseSample | None: + """Aggregate samples close in time. + This function is used to match samples from multiple base stations into a single LhCfPoseSample. + The function will return None if the number of base stations in the sample is less than + the minimum number of base stations required for a match. + Note that a pose sample is returned upon the next call to this function, that is when the maximum time diff of + the first sample in the group has been exceeded. + + Args: + sample (LhMeasurement): angles from one base station + + Returns: + LhCfPoseSample | None: a pose sample if available, otherwise None + """ + result = None + if len(self._current_angles) > 0: + if sample.timestamp > (self._current_ts + self.max_time_diff): + if len(self._current_angles) >= self.min_nr_of_bs_in_match: + result = LhCfPoseSample(self._current_angles, timestamp=self._current_ts) + + self._current_angles = {} + + if len(self._current_angles) == 0: + self._current_ts = sample.timestamp + + self._current_angles[sample.base_station_id] = sample.angles + + return result + + def purge(self) -> LhCfPoseSample | None: + """Purge the current angles and return a pose sample if available. + + Returns: + LhCfPoseSample | None: a pose sample if available, otherwise None + """ + result = None + + if len(self._current_angles) >= self.min_nr_of_bs_in_match: + result = LhCfPoseSample(self._current_angles, timestamp=self._current_ts) + + self._current_angles = {} + self._current_ts = 0.0 + + return result + @classmethod def match(cls, samples: list[LhMeasurement], max_time_diff: float = 0.020, - min_nr_of_bs_in_match: int = 0) -> list[LhCfPoseSample]: + min_nr_of_bs_in_match: int = 1) -> list[LhCfPoseSample]: """ - Aggregate samples close in time into lists + Aggregate samples in a list """ result = [] - current: LhCfPoseSample = None + matcher = cls(max_time_diff, min_nr_of_bs_in_match) for sample in samples: - ts = sample.timestamp - - if current is None: - current = LhCfPoseSample(timestamp=ts) - - if ts > (current.timestamp + max_time_diff): - cls._append_result(current, result, min_nr_of_bs_in_match) - current = LhCfPoseSample(timestamp=ts) + pose_sample = matcher.match_one(sample) + if pose_sample is not None: + result.append(pose_sample) - current.angles_calibrated[sample.base_station_id] = sample.angles + pose_sample = matcher.purge() + if pose_sample is not None: + result.append(pose_sample) - cls._append_result(current, result, min_nr_of_bs_in_match) return result - - @classmethod - def _append_result(cls, current: LhCfPoseSample, result: list[LhCfPoseSample], min_nr_of_bs_in_match: int): - if current is not None and len(current.angles_calibrated) >= min_nr_of_bs_in_match: - result.append(current) diff --git a/cflib/localization/lighthouse_sweep_angle_reader.py b/cflib/localization/lighthouse_sweep_angle_reader.py index 8c653c4a5..aa0652000 100644 --- a/cflib/localization/lighthouse_sweep_angle_reader.py +++ b/cflib/localization/lighthouse_sweep_angle_reader.py @@ -19,8 +19,13 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from collections.abc import Callable +from threading import Timer + +from cflib.crazyflie import Crazyflie from cflib.localization import LighthouseBsVector from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample class LighthouseSweepAngleReader(): @@ -30,7 +35,7 @@ class LighthouseSweepAngleReader(): ANGLE_STREAM_PARAM = 'locSrv.enLhAngleStream' NR_OF_SENSORS = 4 - def __init__(self, cf, data_recevied_cb): + def __init__(self, cf: Crazyflie, data_recevied_cb): self._cf = cf self._cb = data_recevied_cb self._is_active = False @@ -48,7 +53,7 @@ def stop(self): self._cf.loc.receivedLocationPacket.remove_callback(self._packet_received_cb) self._angle_stream_activate(False) - def _angle_stream_activate(self, is_active): + def _angle_stream_activate(self, is_active: bool): value = 0 if is_active: value = 1 @@ -59,11 +64,11 @@ def _packet_received_cb(self, packet): return if self._cb: - base_station_id = packet.data['basestation'] - horiz_angles = packet.data['x'] - vert_angles = packet.data['y'] + base_station_id: int = packet.data['basestation'] + horiz_angles: float = packet.data['x'] + vert_angles: float = packet.data['y'] - result = [] + result: list[LighthouseBsVector] = [] for i in range(self.NR_OF_SENSORS): result.append(LighthouseBsVector(horiz_angles[i], vert_angles[i])) @@ -75,7 +80,7 @@ class LighthouseSweepAngleAverageReader(): Helper class to make it easy read sweep angles for multiple base stations and average the result """ - def __init__(self, cf, ready_cb): + def __init__(self, cf: Crazyflie, ready_cb: Callable[[dict[int, tuple[int, LighthouseBsVectors]]], None]): self._reader = LighthouseSweepAngleReader(cf, self._data_recevied_cb) self._ready_cb = ready_cb self.nr_of_samples_required = 50 @@ -84,7 +89,7 @@ def __init__(self, cf, ready_cb): # The storage is a dictionary keyed on the base station channel # Each entry is a list of 4 lists, one per sensor. # Each list contains LighthouseBsVector objects, representing the sampled sweep angles - self._sample_storage = None + self._sample_storage: dict[int, list[list[LighthouseBsVector]]] | None = None def start_angle_collection(self): """ @@ -103,16 +108,18 @@ def is_collecting(self): """True if data collection is in progress""" return self._sample_storage is not None - def _data_recevied_cb(self, base_station_id, bs_vectors): - self._store_sample(base_station_id, bs_vectors, self._sample_storage) - if self._has_collected_enough_data(self._sample_storage): - self._reader.stop() - if self._ready_cb: - averages = self._average_all_lists(self._sample_storage) - self._ready_cb(averages) - self._sample_storage = None + def _data_recevied_cb(self, base_station_id: int, bs_vectors: list[LighthouseBsVector]): + if self._sample_storage is not None: + self._store_sample(base_station_id, bs_vectors, self._sample_storage) + if self._has_collected_enough_data(self._sample_storage): + self._reader.stop() + if self._ready_cb: + averages = self._average_all_lists(self._sample_storage) + self._ready_cb(averages) + self._sample_storage = None - def _store_sample(self, base_station_id, bs_vectors, storage): + def _store_sample(self, base_station_id: int, bs_vectors: list[LighthouseBsVector], + storage: dict[int, list[list[LighthouseBsVector]]]): if base_station_id not in storage: storage[base_station_id] = [] for sensor in range(self._reader.NR_OF_SENSORS): @@ -121,31 +128,32 @@ def _store_sample(self, base_station_id, bs_vectors, storage): for sensor in range(self._reader.NR_OF_SENSORS): storage[base_station_id][sensor].append(bs_vectors[sensor]) - def _has_collected_enough_data(self, storage): + def _has_collected_enough_data(self, storage: dict[int, list[list[LighthouseBsVector]]]): for sample_list in storage.values(): if len(sample_list[0]) >= self.nr_of_samples_required: return True return False - def _average_all_lists(self, storage): - result = {} + def _average_all_lists(self, storage: dict[int, list[list[LighthouseBsVector]]] + ) -> dict[int, tuple[int, LighthouseBsVectors]]: + result: dict[int, tuple[int, LighthouseBsVectors]] = {} - for id, sample_lists in storage.items(): + for bs_id, sample_lists in storage.items(): averages = self._average_sample_lists(sample_lists) count = len(sample_lists[0]) - result[id] = (count, averages) + result[bs_id] = (count, averages) return result - def _average_sample_lists(self, sample_lists): - result = [] + def _average_sample_lists(self, sample_lists: list[list[LighthouseBsVector]]) -> LighthouseBsVectors: + result: list[LighthouseBsVector] = [] for i in range(self._reader.NR_OF_SENSORS): result.append(self._average_sample_list(sample_lists[i])) return LighthouseBsVectors(result) - def _average_sample_list(self, sample_list): + def _average_sample_list(self, sample_list: list[LighthouseBsVector]) -> LighthouseBsVector: sum_horiz = 0.0 sum_vert = 0.0 @@ -155,3 +163,117 @@ def _average_sample_list(self, sample_list): count = len(sample_list) return LighthouseBsVector(sum_horiz / count, sum_vert / count) + + +class LighthouseMatchedSweepAngleReader(): + """ + Wrapper to simplify reading of matched lighthouse sweep angles from the locSrv stream + """ + MATCHED_STREAM_PARAM = 'locSrv.enLhMtchStm' + MATCHED_STREAM_MIN_BS_PARAM = 'locSrv.minBsLhMtchStm' + MATCHED_STREAM_MAX_TIME_PARAM = 'locSrv.maxTimeLhMtchStm' + NR_OF_SENSORS = 4 + + def __init__(self, cf: Crazyflie, data_recevied_cb, timeout_cb=None, sample_count: int = 1, min_bs: int = 2, + max_time_ms: int = 25): + self._cf = cf + self._data_cb = data_recevied_cb + self._timeout_cb = timeout_cb + self._is_active = False + self._sample_count = sample_count + self._sample_count_remaining = 0 + + # The maximum number of base stations is limited in the CF due to memory considerations. + if min_bs > 4: + raise ValueError('Minimum base station count must be 4 or less') + self._min_bs = min_bs + + self._max_time_ms = max_time_ms + + self._current_group_id = 0 + self._angles: dict[int, LighthouseBsVectors] = {} + + self._timeout_timer = None + + def start(self, timeout: float = 0.0): + """Start reading sweep angles + + Args: + timeout (float): timeout in seconds, 0.0 means no timeout + """ + self._cf.loc.receivedLocationPacket.add_callback(self._packet_received_cb) + self._is_active = True + self._angle_stream_activate(True) + self._sample_count_remaining = self._sample_count + + self._clear_timer() + self._timeout_timer = Timer(timeout, self._timer_done_cb) + self._timeout_timer.start() + + def stop(self): + """Stop reading sweep angles""" + if self._is_active: + self._is_active = False + self._clear_timer() + self._cf.loc.receivedLocationPacket.remove_callback(self._packet_received_cb) + self._angle_stream_activate(False) + + def _clear_timer(self): + if self._timeout_timer is not None: + self._timeout_timer.cancel() + self._timeout_timer = None + + def _timer_done_cb(self): + self.stop() + if self._timeout_cb: + self._timeout_cb() + + def _angle_stream_activate(self, is_active: bool): + value = 0 + if is_active: + value = self._sample_count + self._cf.param.set_value(self.MATCHED_STREAM_PARAM, value) + + self._cf.param.set_value(self.MATCHED_STREAM_MIN_BS_PARAM, self._min_bs) + self._cf.param.set_value(self.MATCHED_STREAM_MAX_TIME_PARAM, self._max_time_ms) + + def _packet_received_cb(self, packet): + if self._is_active: + if packet.type != self._cf.loc.LH_MATCHED_ANGLE_STREAM: + return + + base_station_id: int = packet.data['basestation'] + horiz_angles: float = packet.data['x'] + vert_angles: float = packet.data['y'] + group_id: int = packet.data['group_id'] + bs_count: int = packet.data['bs_count'] + + if group_id != self._current_group_id: + if len(self._angles) >= self._min_bs: + # We have enough angles in the previous group even though all angles were not received + # Lost a packet? + self._call_data_callback() + + # Reset + self._current_group_id = group_id + self._angles = {} + + vectors: list[LighthouseBsVector] = [] + for i in range(self.NR_OF_SENSORS): + vectors.append(LighthouseBsVector(horiz_angles[i], vert_angles[i])) + self._angles[base_station_id] = LighthouseBsVectors(vectors) + + if len(self._angles) == bs_count: + # We have received all angles for this group, call the callback + self._call_data_callback() + + if self._sample_count_remaining <= 0: + # We have received enough samples, stop the reader + self.stop() + + def _call_data_callback(self): + self._sample_count_remaining -= 1 + + if self._data_cb: + self._data_cb(LhCfPoseSample(self._angles)) + self._angles = {} diff --git a/cflib/localization/lighthouse_system_aligner.py b/cflib/localization/lighthouse_system_aligner.py index 3ec964a0a..158fc1903 100644 --- a/cflib/localization/lighthouse_system_aligner.py +++ b/cflib/localization/lighthouse_system_aligner.py @@ -109,21 +109,25 @@ def _Pose_from_params(cls, params: npt.ArrayLike) -> Pose: def _de_flip_transformation(cls, raw_transformation: Pose, x_axis: list[npt.ArrayLike], bs_poses: dict[int, Pose]) -> Pose: """ - Investigats a transformation and flips it if needed. This method assumes that - 1. all base stations are at Z>0 - 2. x_axis samples are taken at X>0 + Examines a transformation and flips it if needed. This method assumes that + 1. most base stations are at Z > 0 + 2. x_axis samples are taken at X > 0 """ transformation = raw_transformation - # X-axis poses should be on the positivie X-axis, check that the "mean" of the x-axis points ends up at X>0 + # X-axis poses should be on the positive X-axis, check that the "mean" of the x-axis points ends up at X>0 x_axis_mean = np.mean(x_axis, axis=0) if raw_transformation.rotate_translate(x_axis_mean)[0] < 0.0: flip_around_z_axis = Pose.from_rot_vec(R_vec=(0.0, 0.0, np.pi)) transformation = flip_around_z_axis.rotate_translate_pose(transformation) - # Base station poses should be above the floor, check the first one - bs_pose = list(bs_poses.values())[0] - if raw_transformation.rotate_translate(bs_pose.translation)[2] < 0.0: + # Assume base station poses should be above the floor. It is possible that the estimate of one or a few of them + # is slightly negative if they are placed on the floor, use an average of the z of all base stations. + def rotate_translate_get_z(bs_pose: Pose) -> float: + return raw_transformation.rotate_translate(bs_pose.translation)[2] + + bs_z_mean = np.mean(list(map(rotate_translate_get_z, bs_poses.values()))) + if bs_z_mean < 0.0: flip_around_x_axis = Pose.from_rot_vec(R_vec=(np.pi, 0.0, 0.0)) transformation = flip_around_x_axis.rotate_translate_pose(transformation) diff --git a/cflib/localization/lighthouse_system_scaler.py b/cflib/localization/lighthouse_system_scaler.py index a873dba03..8e9469328 100644 --- a/cflib/localization/lighthouse_system_scaler.py +++ b/cflib/localization/lighthouse_system_scaler.py @@ -27,8 +27,8 @@ import numpy.typing as npt from cflib.localization.lighthouse_bs_vector import LighthouseBsVector -from cflib.localization.lighthouse_types import LhCfPoseSample -from cflib.localization.lighthouse_types import Pose +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_cf_pose_sample import Pose class LighthouseSystemScaler: diff --git a/cflib/localization/lighthouse_types.py b/cflib/localization/lighthouse_types.py index 941bc5e74..23aa29683 100644 --- a/cflib/localization/lighthouse_types.py +++ b/cflib/localization/lighthouse_types.py @@ -25,6 +25,7 @@ import numpy as np import numpy.typing as npt +import yaml from scipy.spatial.transform import Rotation from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors @@ -135,6 +136,32 @@ def inv_rotate_translate_pose(self, pose: 'Pose') -> 'Pose': return Pose(R_matrix=R, t_vec=t) + def __eq__(self, other): + if not isinstance(other, Pose): + return NotImplemented + + return np.array_equal(self._R_matrix, other._R_matrix) and np.array_equal(self._t_vec, other._t_vec) + + @staticmethod + def yaml_representer(dumper, data: Pose): + """Represent a Pose object in YAML""" + return dumper.represent_mapping('!Pose', { + 'R_matrix': data.rot_matrix.tolist(), + 't_vec': data.translation.tolist() + }) + + @staticmethod + def yaml_constructor(loader, node): + """Construct a Pose object from YAML""" + values = loader.construct_mapping(node, deep=True) + R_matrix = np.array(values['R_matrix']) + t_vec = np.array(values['t_vec']) + return Pose(R_matrix=R_matrix, t_vec=t_vec) + + +yaml.add_representer(Pose, Pose.yaml_representer) +yaml.add_constructor('!Pose', Pose.yaml_constructor) + class LhMeasurement(NamedTuple): """Represents a measurement from one base station.""" @@ -149,25 +176,6 @@ class LhBsCfPoses(NamedTuple): cf_poses: list[Pose] -class LhCfPoseSample: - """ Represents a sample of a Crazyflie pose in space, it contains - various data related to the pose such as: - - lighthouse angles from one or more base stations - - initial estimate of the pose - - refined estimate of the pose - - estimated errors - """ - - def __init__(self, timestamp: float = 0.0, angles_calibrated: dict[int, LighthouseBsVectors] = None) -> None: - self.timestamp: float = timestamp - - # Angles measured by the Crazyflie and compensated using calibration data - # Stored in a dictionary using base station id as the key - self.angles_calibrated: dict[int, LighthouseBsVectors] = angles_calibrated - if self.angles_calibrated is None: - self.angles_calibrated = {} - - class LhDeck4SensorPositions: """ Positions of the sensors on the Lighthouse 4 deck """ # Sensor distances on the lighthouse deck diff --git a/cflib/localization/lighthouse_utils.py b/cflib/localization/lighthouse_utils.py new file mode 100644 index 000000000..9e614e7ad --- /dev/null +++ b/cflib/localization/lighthouse_utils.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +# +# ,---------, ____ _ __ +# | ,-^-, | / __ )(_) /_______________ _____ ___ +# | ( O ) | / __ / / __/ ___/ ___/ __ `/_ / / _ \ +# | / ,--' | / /_/ / / /_/ /__/ / / /_/ / / /_/ __/ +# +------` /_____/_/\__/\___/_/ \__,_/ /___/\___/ +# +# Copyright (C) 2025 Bitcraze AB +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, in version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +import numpy as np +import numpy.typing as npt + +from cflib.localization.lighthouse_bs_vector import LighthouseBsVector +from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_types import Pose + + +class LighthouseCrossingBeam: + """A class to calculate the crossing point of two "beams" from two base stations. The beams are defined by the line + where the two light planes intersect. In a perfect world the crossing point of the two beams is the position of + a sensor on the Crazyflie Lighthouse deck, but in reality the beams will most likely not cross and instead we + use the closest point between the two beams as the position estimate. The (minimum) distance between the beams + is also calculated and can be used as an error estimate for the position. + """ + + @classmethod + def position_distance(cls, + bs1: Pose, angles_bs1: LighthouseBsVector, + bs2: Pose, angles_bs2: LighthouseBsVector) -> tuple[npt.NDArray, float]: + """Calculate the estimated position of the crossing point of the beams + from two base stations as well as the distance. + + Args: + bs1 (Pose): The pose of the first base station. + angles_bs1 (LighthouseBsVector): The sweep angles of the first base station. + bs2 (Pose): The pose of the second base station. + angles_bs2 (LighthouseBsVector): The sweep angles of the second base station. + + Returns: + tuple[npt.NDArray, float]: The estimated position of the crossing point and the distance between the beams. + """ + orig_1 = bs1.translation + vec_1 = bs1.rot_matrix @ angles_bs1.cart + + orig_2 = bs2.translation + vec_2 = bs2.rot_matrix @ angles_bs2.cart + + return cls._position_distance(orig_1, vec_1, orig_2, vec_2) + + @classmethod + def position(cls, + bs1: Pose, angles_bs1: LighthouseBsVector, + bs2: Pose, angles_bs2: LighthouseBsVector) -> npt.NDArray: + """Calculate the estimated position of the crossing point of the beams + from two base stations. + + Args: + bs1 (Pose): The pose of the first base station. + angles_bs1 (LighthouseBsVector): The sweep angles of the first base station. + bs2 (Pose): The pose of the second base station. + angles_bs2 (LighthouseBsVector): The sweep angles of the second base station. + + Returns: + npt.NDArray: The estimated position of the crossing point of the two beams. + """ + position, _ = cls.position_distance(bs1, angles_bs1, bs2, angles_bs2) + return position + + @classmethod + def distance(cls, + bs1: Pose, angles_bs1: LighthouseBsVector, + bs2: Pose, angles_bs2: LighthouseBsVector) -> float: + """Calculate the minimum distance between the beams from two base stations. + + Args: + bs1 (Pose): The pose of the first base station. + angles_bs1 (LighthouseBsVector): The sweep angles of the first base station. + bs2 (Pose): The pose of the second base station. + angles_bs2 (LighthouseBsVector): The sweep angles of the second base station. + + Returns: + float: The shortest distance between the beams. + """ + _, distance = cls.position_distance(bs1, angles_bs1, bs2, angles_bs2) + return distance + + @classmethod + def distances(cls, + bs1: Pose, angles_bs1: LighthouseBsVectors, + bs2: Pose, angles_bs2: LighthouseBsVectors) -> list[float]: + """Calculate the minimum distance between the beams from two base stations for all sensors. + + Args: + bs1 (Pose): The pose of the first base station. + angles_bs1 (LighthouseBsVectors): The sweep angles of the first base station. + bs2 (Pose): The pose of the second base station. + angles_bs2 (LighthouseBsVectors): The sweep angles of the second base station. + + Returns: + list[float]: A list of the distances. + """ + return [cls.distance(bs1, angles1, bs2, angles2) for angles1, angles2 in zip(angles_bs1, angles_bs2)] + + @classmethod + def max_distance(cls, + bs1: Pose, angles_bs1: LighthouseBsVectors, + bs2: Pose, angles_bs2: LighthouseBsVectors) -> float: + """Calculate the maximum distance between the beams from two base stations for all sensors. + + Args: + bs1 (Pose): The pose of the first base station. + angles_bs1 (LighthouseBsVectors): The sweep angles of the first base station. + bs2 (Pose): The pose of the second base station. + angles_bs2 (LighthouseBsVectors): The sweep angles of the second base station. + + Returns: + float: The maximum distance between the beams. + """ + return max(cls.distances(bs1, angles_bs1, bs2, angles_bs2)) + + @classmethod + def max_distance_all_permutations(cls, bs_angles: list[tuple[Pose, LighthouseBsVectors]]) -> float: + """Calculate the maximum distance between the beams from base stations for all sensors. All permutations of + base stations are considered. This result can be used as an estimation of the maximum error. + + Args: + bs_angles (list[tuple[Pose, LighthouseBsVectors]]): A list of tuples containing the pose of the base + stations and their sweep angles. + + Returns: + float: The maximum distance between the beams from all permutations of base stations. + """ + if len(bs_angles) < 2: + raise ValueError('At least two base stations are required to calculate the maximum distance.') + + max_distance = 0.0 + bs_count = len(bs_angles) + for i1 in range(bs_count - 1): + for i2 in range(i1 + 1, bs_count): + bs1, angles_bs1 = bs_angles[i1] + bs2, angles_bs2 = bs_angles[i2] + # Calculate the distance for this pair of base stations + distance = cls.max_distance(bs1, angles_bs1, bs2, angles_bs2) + max_distance = max(max_distance, distance) + + return max_distance + + @classmethod + def _position_distance(cls, + orig_1: npt.NDArray, vec_1: npt.NDArray, + orig_2: npt.NDArray, vec_2: npt.NDArray) -> tuple[npt.NDArray, float]: + w0 = orig_1 - orig_2 + a = np.dot(vec_1, vec_1) + b = np.dot(vec_1, vec_2) + c = np.dot(vec_2, vec_2) + d = np.dot(vec_1, w0) + e = np.dot(vec_2, w0) + + denom = a * c - b * b + + # Closest point to line 2 on line 1 + t = (b * e - c * d) / denom + pt1 = orig_1 + t * vec_1 + + # Closest point to line 1 on line 2 + t = (a * e - b * d) / denom + pt2 = orig_2 + t * vec_2 + + # Point between the two lines + pt = (pt1 + pt2) / 2 + + # Distance between the two closest points of the beams + distance = np.linalg.norm(pt1 - pt2) + + return pt, float(distance) diff --git a/cflib/localization/user_action_detector.py b/cflib/localization/user_action_detector.py new file mode 100644 index 000000000..278b2a543 --- /dev/null +++ b/cflib/localization/user_action_detector.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# || ____ _ __ +# +------+ / __ )(_) /_______________ _____ ___ +# | 0xBC | / __ / / __/ ___/ ___/ __ `/_ / / _ \ +# +------+ / /_/ / / /_/ /__/ / / /_/ / / /_/ __/ +# || || /_____/_/\__/\___/_/ \__,_/ /___/\___/ +# +# Copyright (C) 2025 Bitcraze AB +# +# Crazyflie Nano Quadcopter Client +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 2 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Functionality to get user input by shaking the Crazyflie. +""" +import time + +from cflib.crazyflie import Crazyflie +from cflib.crazyflie.log import LogConfig + + +class UserActionDetector: + """ This class is used as an user interface that lets the user trigger an event by using the Crazyflie as the + input device. The class listens to the z component of the gyro and detects a quick left or right rotation followed + by period of no motion. If such a sequence is detected, it calls the callback function provided in the constructor. + """ + + def __init__(self, cf: Crazyflie, cb=None): + self._is_active = False + self._reset() + self._cf = cf + self._cb = cb + self._lg_config = None + + self.left_event_threshold_time = 0.0 + self.left_event_time = 0.0 + self.right_event_threshold_time = 0.0 + self.right_event_time = 0 + self.still_event_threshold_time = 0.0 + self.still_event_time = 0.0 + + def start(self): + if not self._is_active: + self._is_active = True + self._reset() + self._cf.disconnected.add_callback(self._disconnected_callback) + + self._lg_config = LogConfig(name='lighthouse_geo_estimator', period_in_ms=25) + self._lg_config.add_variable('gyro.z', 'float') + self._cf.log.add_config(self._lg_config) + self._lg_config.data_received_cb.add_callback(self._log_callback) + self._lg_config.start() + + def stop(self): + if self._is_active: + if self._lg_config is not None: + self._lg_config.stop() + self._lg_config.delete() + self._lg_config.data_received_cb.remove_callback(self._log_callback) + self._lg_config = None + self._cf.disconnected.remove_callback(self._disconnected_callback) + self._is_active = False + + def _disconnected_callback(self, uri): + self.stop() + + def _log_callback(self, ts, data, logblock): + if self._is_active: + gyro_z = data['gyro.z'] + self.process_rot(gyro_z) + + def _reset(self): + self.left_event_threshold_time = 0.0 + self.left_event_time = 0.0 + + self.right_event_threshold_time = 0.0 + self.right_event_time = 0 + + self.still_event_threshold_time = 0.0 + self.still_event_time = 0.0 + + def process_rot(self, gyro_z): + now = time.time() + + MAX_DURATION_OF_EVENT_PEEK = 0.1 + MIN_DURATION_OF_STILL_EVENT = 0.5 + MAX_TIME_BETWEEN_LEFT_RIGHT_EVENTS = 0.3 + MAX_TIME_BETWEEN_FIRST_ROTATION_AND_STILL_EVENT = 1.0 + + if gyro_z > 0: + self.left_event_threshold_time = now + if gyro_z < -300 and now - self.left_event_threshold_time < MAX_DURATION_OF_EVENT_PEEK: + self.left_event_time = now + + if gyro_z < 0: + self.right_event_threshold_time = now + if gyro_z > 300 and now - self.right_event_threshold_time < MAX_DURATION_OF_EVENT_PEEK: + self.right_event_time = now + + if abs(gyro_z) > 50: + self.still_event_threshold_time = now + if abs(gyro_z) < 30 and now - self.still_event_threshold_time > MIN_DURATION_OF_STILL_EVENT: + self.still_event_time = now + + dt_left_right = self.left_event_time - self.right_event_time + first_left_right = min(self.left_event_time, self.right_event_time) + dt_first_still = self.still_event_time - first_left_right + + if self.left_event_time > 0 and self.right_event_time > 0 and self.still_event_time > 0: + if (abs(dt_left_right) < MAX_TIME_BETWEEN_LEFT_RIGHT_EVENTS and + dt_first_still > 0 and + dt_first_still < MAX_TIME_BETWEEN_FIRST_ROTATION_AND_STILL_EVENT): + self._reset() + if self._cb is not None: + self._cb() diff --git a/examples/lighthouse/multi_bs_geometry_estimation.py b/examples/lighthouse/multi_bs_geometry_estimation.py index 4f6dc7c15..bea4f49a1 100644 --- a/examples/lighthouse/multi_bs_geometry_estimation.py +++ b/examples/lighthouse/multi_bs_geometry_estimation.py @@ -44,7 +44,6 @@ from __future__ import annotations import logging -import pickle import time from threading import Event @@ -55,30 +54,32 @@ from cflib.crazyflie.mem.lighthouse_memory import LighthouseBsGeometry from cflib.crazyflie.syncCrazyflie import SyncCrazyflie from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_cf_pose_sample import Pose from cflib.localization.lighthouse_config_manager import LighthouseConfigWriter -from cflib.localization.lighthouse_geometry_solver import LighthouseGeometrySolver -from cflib.localization.lighthouse_initial_estimator import LighthouseInitialEstimator -from cflib.localization.lighthouse_sample_matcher import LighthouseSampleMatcher +from cflib.localization.lighthouse_geo_estimation_manager import LhGeoEstimationManager +from cflib.localization.lighthouse_geo_estimation_manager import LhGeoInputContainer +from cflib.localization.lighthouse_geo_estimation_manager import LhGeoInputContainerData +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution +from cflib.localization.lighthouse_sweep_angle_reader import LighthouseMatchedSweepAngleReader from cflib.localization.lighthouse_sweep_angle_reader import LighthouseSweepAngleAverageReader from cflib.localization.lighthouse_sweep_angle_reader import LighthouseSweepAngleReader -from cflib.localization.lighthouse_system_aligner import LighthouseSystemAligner -from cflib.localization.lighthouse_system_scaler import LighthouseSystemScaler -from cflib.localization.lighthouse_types import LhCfPoseSample +from cflib.localization.lighthouse_types import LhBsCfPoses from cflib.localization.lighthouse_types import LhDeck4SensorPositions from cflib.localization.lighthouse_types import LhMeasurement -from cflib.localization.lighthouse_types import Pose +from cflib.localization.user_action_detector import UserActionDetector from cflib.utils import uri_helper REFERENCE_DIST = 1.0 -def record_angles_average(scf: SyncCrazyflie, timeout: float = 5.0) -> LhCfPoseSample: +def record_angles_average(scf: SyncCrazyflie, timeout: float = 5.0) -> LhCfPoseSample | None: """Record angles and average over the samples to reduce noise""" - recorded_angles = None + recorded_angles: dict[int, tuple[int, LighthouseBsVectors]] | None = None is_ready = Event() - def ready_cb(averages): + def ready_cb(averages: dict[int, tuple[int, LighthouseBsVectors]]): nonlocal recorded_angles recorded_angles = averages is_ready.set() @@ -90,11 +91,11 @@ def ready_cb(averages): print('Recording timed out.') return None - angles_calibrated = {} + angles_calibrated: dict[int, LighthouseBsVectors] = {} for bs_id, data in recorded_angles.items(): angles_calibrated[bs_id] = data[1] - result = LhCfPoseSample(angles_calibrated=angles_calibrated) + result = LhCfPoseSample(angles_calibrated) visible = ', '.join(map(lambda x: str(x + 1), recorded_angles.keys())) print(f' Position recorded, base station ids visible: {visible}') @@ -106,9 +107,9 @@ def ready_cb(averages): return result -def record_angles_sequence(scf: SyncCrazyflie, recording_time_s: float) -> list[LhCfPoseSample]: +def record_angles_sequence(scf: SyncCrazyflie, recording_time_s: float) -> list[LhMeasurement]: """Record angles and return a list of the samples""" - result: list[LhCfPoseSample] = [] + result: list[LhMeasurement] = [] bs_seen = set() @@ -142,11 +143,11 @@ def parse_recording_time(recording_time: str, default: int) -> int: return default -def print_base_stations_poses(base_stations: dict[int, Pose]): +def print_base_stations_poses(base_stations: dict[int, Pose], printer=print): """Pretty print of base stations pose""" for bs_id, pose in sorted(base_stations.items()): pos = pose.translation - print(f' {bs_id + 1}: ({pos[0]}, {pos[1]}, {pos[2]})') + printer(f' {bs_id + 1}: ({pos[0]}, {pos[1]}, {pos[2]})') def set_axes_equal(ax): @@ -178,15 +179,15 @@ def set_axes_equal(ax): ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) -def visualize(cf_poses: list[Pose], bs_poses: list[Pose]): +def visualize(poses: LhBsCfPoses): """Visualize positions of base stations and Crazyflie positions""" # Set to True to visualize positions # Requires PyPlot - visualize_positions = False + visualize_positions = True if visualize_positions: import matplotlib.pyplot as plt - positions = np.array(list(map(lambda x: x.translation, cf_poses))) + positions = np.array(list(map(lambda x: x.translation, poses.cf_poses))) fig = plt.figure() ax = fig.add_subplot(projection='3d') @@ -197,7 +198,7 @@ def visualize(cf_poses: list[Pose], bs_poses: list[Pose]): ax.scatter(x_cf, y_cf, z_cf) - positions = np.array(list(map(lambda x: x.translation, bs_poses))) + positions = np.array(list(map(lambda x: x.translation, poses.bs_poses.values()))) x_bs = positions[:, 0] y_bs = positions[:, 1] @@ -210,75 +211,30 @@ def visualize(cf_poses: list[Pose], bs_poses: list[Pose]): plt.show() -def write_to_file(name: str, - origin: LhCfPoseSample, - x_axis: list[LhCfPoseSample], - xy_plane: list[LhCfPoseSample], - samples: list[LhCfPoseSample]): - with open(name, 'wb') as handle: - data = (origin, x_axis, xy_plane, samples) - pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) +def load_from_file(name: str) -> LhGeoInputContainerData: + container = LhGeoInputContainer(LhDeck4SensorPositions.positions) + with open(name, 'r', encoding='UTF8') as handle: + container.populate_from_file_yaml(handle) + return container.get_data_copy() -def load_from_file(name: str): - with open(name, 'rb') as handle: - return pickle.load(handle) +def print_solution(solution: LighthouseGeometrySolution): + def _print(msg: str): + print(f' * {msg}') + _print('Solution ready --------------------------------------') + _print(' Base stations at:') + bs_poses = solution.poses.bs_poses + print_base_stations_poses(bs_poses, printer=_print) - -def estimate_geometry(origin: LhCfPoseSample, - x_axis: list[LhCfPoseSample], - xy_plane: list[LhCfPoseSample], - samples: list[LhCfPoseSample]) -> dict[int, Pose]: - """Estimate the geometry of the system based on samples recorded by a Crazyflie""" - matched_samples = [origin] + x_axis + xy_plane + LighthouseSampleMatcher.match(samples, min_nr_of_bs_in_match=2) - initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate( - matched_samples, LhDeck4SensorPositions.positions) - - print('Initial guess base stations at:') - print_base_stations_poses(initial_guess.bs_poses) - - print(f'{len(cleaned_matched_samples)} samples will be used') - visualize(initial_guess.cf_poses, initial_guess.bs_poses.values()) - - solution = LighthouseGeometrySolver.solve(initial_guess, cleaned_matched_samples, LhDeck4SensorPositions.positions) - if not solution.success: - print('Solution did not converge, it might not be good!') - - start_x_axis = 1 - start_xy_plane = 1 + len(x_axis) - origin_pos = solution.cf_poses[0].translation - x_axis_poses = solution.cf_poses[start_x_axis:start_x_axis + len(x_axis)] - x_axis_pos = list(map(lambda x: x.translation, x_axis_poses)) - xy_plane_poses = solution.cf_poses[start_xy_plane:start_xy_plane + len(xy_plane)] - xy_plane_pos = list(map(lambda x: x.translation, xy_plane_poses)) - - print('Raw solution:') - print(' Base stations at:') - print_base_stations_poses(solution.bs_poses) - print(' Solution match per base station:') - for bs_id, value in solution.error_info['bs'].items(): - print(f' {bs_id + 1}: {value}') - - # Align the solution - bs_aligned_poses, transformation = LighthouseSystemAligner.align( - origin_pos, x_axis_pos, xy_plane_pos, solution.bs_poses) - - cf_aligned_poses = list(map(transformation.rotate_translate_pose, solution.cf_poses)) - - # Scale the solution - bs_scaled_poses, cf_scaled_poses, scale = LighthouseSystemScaler.scale_fixed_point(bs_aligned_poses, - cf_aligned_poses, - [REFERENCE_DIST, 0, 0], - cf_aligned_poses[1]) - - print() - print('Final solution:') - print(' Base stations at:') - print_base_stations_poses(bs_scaled_poses) - - visualize(cf_scaled_poses, bs_scaled_poses.values()) - - return bs_scaled_poses + _print(f'Converged: {solution.has_converged}') + _print(f'Progress info: {solution.progress_info}') + _print(f'Progress is ok: {solution.progress_is_ok}') + _print(f'Origin: {solution.is_origin_sample_valid}, {solution.origin_sample_info}') + _print(f'X-axis: {solution.is_x_axis_samples_valid}, {solution.x_axis_samples_info}') + _print(f'XY-plane: {solution.is_xy_plane_samples_valid}, {solution.xy_plane_samples_info}') + _print(f'XYZ space: {solution.xyz_space_samples_info}') + _print(f'General info: {solution.general_failure_info}') + _print(f'Error info: {solution.error_stats}') def upload_geometry(scf: SyncCrazyflie, bs_poses: dict[int, Pose]): @@ -302,11 +258,12 @@ def data_written(_): def estimate_from_file(file_name: str): - origin, x_axis, xy_plane, samples = load_from_file(file_name) - estimate_geometry(origin, x_axis, xy_plane, samples) + container_data = load_from_file(file_name) + solution = LhGeoEstimationManager.estimate_geometry(container_data) + print_solution(solution) -def get_recording(scf: SyncCrazyflie): +def get_recording(scf: SyncCrazyflie) -> LhCfPoseSample: data = None while True: # Infinite loop, will break on valid measurement input('Press return when ready. ') @@ -314,15 +271,17 @@ def get_recording(scf: SyncCrazyflie): measurement = record_angles_average(scf) if measurement is not None: data = measurement + scf.cf.platform.send_user_notification(True) break # Exit the loop if a valid measurement is obtained else: + scf.cf.platform.send_user_notification(False) time.sleep(1) print('Invalid measurement, please try again.') return data -def get_multiple_recordings(scf: SyncCrazyflie): - data = [] +def get_multiple_recordings(scf: SyncCrazyflie) -> list[LhCfPoseSample]: + data: list[LhCfPoseSample] = [] first_attempt = True while True: @@ -341,8 +300,10 @@ def get_multiple_recordings(scf: SyncCrazyflie): print(' Recording...') measurement = record_angles_average(scf) if measurement is not None: + scf.cf.platform.send_user_notification(True) data.append(measurement) else: + scf.cf.platform.send_user_notification(False) time.sleep(1) print('Invalid measurement, please try again.') @@ -353,46 +314,61 @@ def connect_and_estimate(uri: str, file_name: str | None = None): """Connect to a Crazyflie, collect data and estimate the geometry of the system""" print(f'Step 1. Connecting to the Crazyflie on uri {uri}...') with SyncCrazyflie(uri, cf=Crazyflie(rw_cache='./cache')) as scf: + container = LhGeoInputContainer(LhDeck4SensorPositions.positions) + container.enable_auto_save('lh_geo_sessions') + print('Starting geometry estimation thread...') + + def _local_solution_handler(solution: LighthouseGeometrySolution): + print_solution(solution) + if solution.progress_is_ok: + upload_geometry(scf, solution.poses.bs_poses) + print('Geometry uploaded to Crazyflie.') + + thread = LhGeoEstimationManager.SolverThread(container, is_done_cb=_local_solution_handler) + thread.start() + print(' Connected') print('') print('In the 3 following steps we will define the coordinate system.') print('Step 2. Put the Crazyflie where you want the origin of your coordinate system.') - origin = get_recording(scf) + container.set_origin_sample(get_recording(scf)) print(f'Step 3. Put the Crazyflie on the positive X-axis, exactly {REFERENCE_DIST} meters from the origin. ' + - 'This position defines the direction of the X-axis, but it is also used for scaling of the system.') - x_axis = [get_recording(scf)] + 'This position defines the direction of the X-axis, but it is also used for scaling the system.') + container.set_x_axis_sample(get_recording(scf)) - print('Step 4. Put the Crazyflie somehere in the XY-plane, but not on the X-axis.') + print('Step 4. Put the Crazyflie somewhere in the XY-plane, but not on the X-axis.') print('Multiple samples can be recorded if you want to.') - xy_plane = get_multiple_recordings(scf) + container.set_xy_plane_samples(get_multiple_recordings(scf)) print() print('Step 5. We will now record data from the space you plan to fly in and optimize the base station ' + - 'geometry based on this data. Move the Crazyflie around, try to cover all of the space, make sure ' + - 'all the base stations are received and do not move too fast.') - default_time = 20 - recording_time = input(f'Enter the number of seconds you want to record ({default_time} by default), ' + - 'recording starts when you hit enter. ') - recording_time_s = parse_recording_time(recording_time, default_time) - print(' Recording started...') - samples = record_angles_sequence(scf, recording_time_s) - print(' Recording ended') - - if file_name: - write_to_file(file_name, origin, x_axis, xy_plane, samples) - print(f'Wrote data to file {file_name}') - - print('Step 6. Estimating geometry...') - bs_poses = estimate_geometry(origin, x_axis, xy_plane, samples) - print(' Geometry estimated') - - print('Step 7. Upload geometry to the Crazyflie') - input('Press enter to upload geometry. ') - upload_geometry(scf, bs_poses) - print('Geometry uploaded') + 'geometry based on this data. Sample a position by quickly rotating the Crazyflie ' + + 'around the Z-axis. This will trigger a measurement of the base station angles. ') + + def matched_angles_cb(sample: LhCfPoseSample): + print('Position stored') + scf.cf.platform.send_user_notification(True) + container.append_xyz_space_samples([sample]) + scf.cf.platform.send_user_notification() + + def timeout_cb(): + print('Timeout, no angles received. Please try again.') + scf.cf.platform.send_user_notification(False) + angle_reader = LighthouseMatchedSweepAngleReader(scf.cf, matched_angles_cb, timeout_cb=timeout_cb) + + def user_action_cb(): + print('Sampling...') + angle_reader.start(timeout=1.0) + detector = UserActionDetector(scf.cf, cb=user_action_cb) + + detector.start() + input('Press return to terminate the script when all required positions have been sampled.') + + detector.stop() + thread.stop() # Only output errors from the logging framework @@ -406,7 +382,7 @@ def connect_and_estimate(uri: str, file_name: str | None = None): # Set a file name to write the measurement data to file. Useful for debugging file_name = None - # file_name = 'lh_geo_estimate_data.pickle' + file_name = 'lh_geo_estimate_data.yaml' connect_and_estimate(uri, file_name=file_name) diff --git a/examples/lighthouse/upload_geos.py b/examples/lighthouse/upload_geos.py new file mode 100644 index 000000000..df38c0587 --- /dev/null +++ b/examples/lighthouse/upload_geos.py @@ -0,0 +1,23 @@ +import cflib.crtp # noqa +from cflib.crazyflie import Crazyflie +from cflib.crazyflie.syncCrazyflie import SyncCrazyflie +from cflib.utils import uri_helper +from cflib.localization import LighthouseConfigFileManager, LighthouseConfigWriter + + +# Upload a geometry to one or more Crazyflies. + +mgr = LighthouseConfigFileManager() +geos, calibs, type = mgr.read('/path/to/your/geo.yaml') + +uri_list = [ + "radio://0/70/2M/E7E7E7E770" +] + +# Initialize the low-level drivers +cflib.crtp.init_drivers() + +for uri in uri_list: + with SyncCrazyflie(uri, cf=Crazyflie(rw_cache='./cache')) as scf: + writer = LighthouseConfigWriter(scf.cf) + writer.write_and_store_config(data_stored_cb=None, geos=geos, calibs=calibs) diff --git a/test/localization/test_lighthouse_bs_vector.py b/test/localization/test_lighthouse_bs_vector.py index 8d490cd91..9d0eb0e85 100644 --- a/test/localization/test_lighthouse_bs_vector.py +++ b/test/localization/test_lighthouse_bs_vector.py @@ -22,6 +22,7 @@ from test.localization.lighthouse_test_base import LighthouseTestBase import numpy as np +import yaml from cflib.localization import LighthouseBsVector from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors @@ -162,3 +163,71 @@ def test_conversion_to_angle_list(self): # Assert self.assertVectorsAlmostEqual((0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7), actual) + + def test_LighthouseBsVector_equality(self): + # Fixture + vec1 = LighthouseBsVector(0.0, 1.0) + vec2 = LighthouseBsVector(0.1, 1.1) + vec3 = LighthouseBsVector(0.1, 1.1) + + # Test + # Assert + self.assertNotEqual(vec1, vec2) + self.assertEqual(vec2, vec3) + + def test_LighthouseBsVectors_equality(self): + # Fixture + vectors1 = LighthouseBsVectors(( + LighthouseBsVector(0.1, 0.1), + LighthouseBsVector(0.2, 0.2), + LighthouseBsVector(0.3, 0.3), + LighthouseBsVector(0.4, 0.4), + )) + + vectors2 = LighthouseBsVectors(( + LighthouseBsVector(0.0, 0.1), + LighthouseBsVector(0.2, 0.3), + LighthouseBsVector(0.4, 0.5), + LighthouseBsVector(0.6, 0.7), + )) + + vectors3 = LighthouseBsVectors(( + LighthouseBsVector(0.0, 0.1), + LighthouseBsVector(0.2, 0.3), + LighthouseBsVector(0.4, 0.5), + LighthouseBsVector(0.6, 0.7), + )) + + # Test + # Assert + self.assertNotEqual(vectors1, vectors2) + self.assertEqual(vectors2, vectors3) + + def test_LighthouseBsVector_yaml(self): + # Fixture + expected = LighthouseBsVector(0.1, 1.1) + + # Test + yaml_str = yaml.dump(expected) + actual = yaml.load(yaml_str, Loader=yaml.FullLoader) + + # Assert + self.assertTrue(yaml_str.startswith('!LighthouseBsVector')) + self.assertEqual(expected, actual) + + def test_LighthouseBsVectors_yaml(self): + # Fixture + expected = LighthouseBsVectors(( + LighthouseBsVector(0.1, 0.1), + LighthouseBsVector(0.2, 0.2), + LighthouseBsVector(0.3, 0.3), + LighthouseBsVector(0.4, 0.4), + )) + + # Test + yaml_str = yaml.dump(expected) + actual = yaml.load(yaml_str, Loader=yaml.FullLoader) + + # Assert + self.assertTrue(yaml_str.startswith('!LighthouseBsVectors')) + self.assertEqual(expected, actual) diff --git a/test/localization/test_lighthouse_cf_pose_sample.py b/test/localization/test_lighthouse_cf_pose_sample.py new file mode 100644 index 000000000..a6d4e61f1 --- /dev/null +++ b/test/localization/test_lighthouse_cf_pose_sample.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# +# ,---------, ____ _ __ +# | ,-^-, | / __ )(_) /_______________ _____ ___ +# | ( O ) | / __ / / __/ ___/ ___/ __ `/_ / / _ \ +# | / ,--' | / /_/ / / /_/ /__/ / / /_/ / / /_/ __/ +# +------` /_____/_/\__/\___/_/ \__,_/ /___/\___/ +# +# Copyright (C) 2025 Bitcraze AB +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, in version 3. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from test.localization.lighthouse_test_base import LighthouseTestBase + +import yaml + +from cflib.localization.lighthouse_bs_vector import LighthouseBsVector +from cflib.localization.lighthouse_bs_vector import LighthouseBsVectors +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample + + +class TestLhCfPoseSample(LighthouseTestBase): + def setUp(self): + self.vec1 = LighthouseBsVector(0.0, 1.0) + self.vec2 = LighthouseBsVector(0.1, 1.1) + self.vec3 = LighthouseBsVector(0.2, 1.2) + self.vec4 = LighthouseBsVector(0.3, 1.3) + + self.sample1 = LhCfPoseSample({}) + self.sample2 = LhCfPoseSample({3: LighthouseBsVectors([self.vec1, self.vec2, self.vec3, self.vec4])}) + self.sample3 = LhCfPoseSample({3: LighthouseBsVectors([self.vec4, self.vec3, self.vec2, self.vec1])}) + self.sample4 = LhCfPoseSample({3: LighthouseBsVectors([self.vec4, self.vec3, self.vec2, self.vec1])}) + + def test_equality(self): + # Fixture + # Test + # Assert + self.assertEqual(self.sample3, self.sample4) + self.assertNotEqual(self.sample1, self.sample4) + self.assertNotEqual(self.sample2, self.sample4) + + def test_yaml(self): + # Fixture + expected = self.sample3 + + # Test + yaml_str = yaml.dump(expected) + actual = yaml.load(yaml_str, Loader=yaml.FullLoader) + + # Assert + self.assertTrue(yaml_str.startswith('!LhCfPoseSample')) + self.assertEqual(expected, actual) diff --git a/test/localization/test_lighthouse_geometry_solver.py b/test/localization/test_lighthouse_geometry_solver.py index ad2f2fd29..cf4fef036 100644 --- a/test/localization/test_lighthouse_geometry_solver.py +++ b/test/localization/test_lighthouse_geometry_solver.py @@ -22,15 +22,17 @@ from test.localization.lighthouse_fixtures import LighthouseFixtures from test.localization.lighthouse_test_base import LighthouseTestBase +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution from cflib.localization.lighthouse_geometry_solver import LighthouseGeometrySolver from cflib.localization.lighthouse_initial_estimator import LighthouseInitialEstimator -from cflib.localization.lighthouse_types import LhCfPoseSample from cflib.localization.lighthouse_types import LhDeck4SensorPositions class TestLighthouseGeometrySolver(LighthouseTestBase): def setUp(self): self.fixtures = LighthouseFixtures() + self.solution = LighthouseGeometrySolution() def test_that_two_bs_poses_in_one_sample_are_estimated(self): # Fixture @@ -43,16 +45,17 @@ def test_that_two_bs_poses_in_one_sample_are_estimated(self): bs_id1: self.fixtures.angles_cf_origin_bs1, }), ] + for sample in matched_samples: + sample.augment_with_ippe(LhDeck4SensorPositions.positions) - initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate(matched_samples, - LhDeck4SensorPositions.positions) + initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate(matched_samples, self.solution) # Test - actual = LighthouseGeometrySolver.solve( - initial_guess, cleaned_matched_samples, LhDeck4SensorPositions.positions) + LighthouseGeometrySolver.solve( + initial_guess, cleaned_matched_samples, LhDeck4SensorPositions.positions, self.solution) # Assert - bs_poses = actual.bs_poses + bs_poses = self.solution.poses.bs_poses self.assertPosesAlmostEqual(self.fixtures.BS0_POSE, bs_poses[bs_id0], places=3) self.assertPosesAlmostEqual(self.fixtures.BS1_POSE, bs_poses[bs_id1], places=3) @@ -77,16 +80,17 @@ def test_that_linked_bs_poses_in_multiple_samples_are_estimated(self): bs_id3: self.fixtures.angles_cf2_bs3, }), ] + for sample in matched_samples: + sample.augment_with_ippe(LhDeck4SensorPositions.positions) - initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate(matched_samples, - LhDeck4SensorPositions.positions) + initial_guess, cleaned_matched_samples = LighthouseInitialEstimator.estimate(matched_samples, self.solution) # Test - actual = LighthouseGeometrySolver.solve( - initial_guess, cleaned_matched_samples, LhDeck4SensorPositions.positions) + LighthouseGeometrySolver.solve( + initial_guess, cleaned_matched_samples, LhDeck4SensorPositions.positions, self.solution) # Assert - bs_poses = actual.bs_poses + bs_poses = self.solution.poses.bs_poses self.assertPosesAlmostEqual(self.fixtures.BS0_POSE, bs_poses[bs_id0], places=3) self.assertPosesAlmostEqual(self.fixtures.BS1_POSE, bs_poses[bs_id1], places=3) self.assertPosesAlmostEqual(self.fixtures.BS2_POSE, bs_poses[bs_id2], places=3) diff --git a/test/localization/test_lighthouse_initial_estimator.py b/test/localization/test_lighthouse_initial_estimator.py index b011558b5..cff39835f 100644 --- a/test/localization/test_lighthouse_initial_estimator.py +++ b/test/localization/test_lighthouse_initial_estimator.py @@ -24,29 +24,32 @@ import numpy as np +from cflib.localization.lighthouse_cf_pose_sample import LhCfPoseSample +from cflib.localization.lighthouse_cf_pose_sample import Pose +from cflib.localization.lighthouse_geometry_solution import LighthouseGeometrySolution from cflib.localization.lighthouse_initial_estimator import LighthouseInitialEstimator -from cflib.localization.lighthouse_types import LhCfPoseSample from cflib.localization.lighthouse_types import LhDeck4SensorPositions -from cflib.localization.lighthouse_types import LhException -from cflib.localization.lighthouse_types import Pose class TestLighthouseInitialEstimator(LighthouseTestBase): def setUp(self): self.fixtures = LighthouseFixtures() + self.solution = LighthouseGeometrySolution() - def test_that_one_bs_pose_raises_exception(self): + def test_that_one_bs_pose_failes_solution(self): # Fixture # CF_ORIGIN is used in the first sample and will define the global reference frame bs_id = 3 samples = [ LhCfPoseSample(angles_calibrated={bs_id: self.fixtures.angles_cf_origin_bs0}), ] + self.augment(samples) # Test + LighthouseInitialEstimator.estimate(samples, self.solution) + # Assert - with self.assertRaises(LhException): - LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + assert self.solution.progress_is_ok is False def test_that_two_bs_poses_in_same_sample_are_found(self): # Fixture @@ -59,9 +62,10 @@ def test_that_two_bs_poses_in_same_sample_are_found(self): bs_id1: self.fixtures.angles_cf_origin_bs1, }), ] + self.augment(samples) # Test - actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, self.solution) # Assert self.assertPosesAlmostEqual(self.fixtures.BS0_POSE, actual.bs_poses[bs_id0], places=3) @@ -88,9 +92,10 @@ def test_that_linked_bs_poses_in_multiple_samples_are_found(self): bs_id3: self.fixtures.angles_cf2_bs3, }), ] + self.augment(samples) # Test - actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, self.solution) # Assert self.assertPosesAlmostEqual(self.fixtures.BS0_POSE, actual.bs_poses[bs_id0], places=3) @@ -119,9 +124,10 @@ def test_that_cf_poses_are_estimated(self): bs_id3: self.fixtures.angles_cf2_bs3, }), ] + self.augment(samples) # Test - actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, self.solution) # Assert self.assertPosesAlmostEqual(self.fixtures.CF_ORIGIN_POSE, actual.cf_poses[0], places=3) @@ -144,9 +150,10 @@ def test_that_the_global_ref_frame_is_used(self): bs_id2: self.fixtures.angles_cf1_bs2, }), ] + self.augment(samples) # Test - actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + actual, cleaned_samples = LighthouseInitialEstimator.estimate(samples, self.solution) # Assert self.assertPosesAlmostEqual( @@ -156,7 +163,7 @@ def test_that_the_global_ref_frame_is_used(self): self.assertPosesAlmostEqual( Pose.from_rot_vec(R_vec=(0.0, 0.0, np.pi), t_vec=(2.0, 1.0, 3.0)), actual.bs_poses[bs_id2], places=3) - def test_that_raises_for_isolated_bs(self): + def test_that_solution_failes_for_isolated_bs(self): # Fixture bs_id0 = 3 bs_id1 = 1 @@ -172,8 +179,51 @@ def test_that_raises_for_isolated_bs(self): bs_id3: self.fixtures.angles_cf2_bs2, }), ] + self.augment(samples) + + # Test + LighthouseInitialEstimator.estimate(samples, self.solution) + + # Assert + assert self.solution.progress_is_ok is False + + def test_that_link_count_is_right(self): + # Fixture + bs_id0 = 3 + bs_id1 = 1 + bs_id2 = 2 + bs_id3 = 4 + samples = [ + LhCfPoseSample(angles_calibrated={ + bs_id0: self.fixtures.angles_cf_origin_bs0, + bs_id1: self.fixtures.angles_cf_origin_bs1, + }), + LhCfPoseSample(angles_calibrated={ + bs_id2: self.fixtures.angles_cf1_bs1, + bs_id3: self.fixtures.angles_cf1_bs2, + }), + LhCfPoseSample(angles_calibrated={ + bs_id0: self.fixtures.angles_cf2_bs0, + bs_id1: self.fixtures.angles_cf2_bs1, + bs_id2: self.fixtures.angles_cf2_bs2, + bs_id3: self.fixtures.angles_cf2_bs3, + }), + ] + self.augment(samples) # Test + LighthouseInitialEstimator.estimate(samples, self.solution) + # Assert - with self.assertRaises(LhException): - LighthouseInitialEstimator.estimate(samples, LhDeck4SensorPositions.positions) + assert self.solution.link_count == { + bs_id0: {bs_id1: 2, bs_id2: 1, bs_id3: 1}, + bs_id1: {bs_id0: 2, bs_id2: 1, bs_id3: 1}, + bs_id2: {bs_id0: 1, bs_id1: 1, bs_id3: 2}, + bs_id3: {bs_id0: 1, bs_id1: 1, bs_id2: 2}, + } + +# helpers + + def augment(self, samples): + for sample in samples: + sample.augment_with_ippe(LhDeck4SensorPositions.positions) diff --git a/test/localization/test_lighthouse_system_aligner.py b/test/localization/test_lighthouse_system_aligner.py index 0e5cea781..e5983164c 100644 --- a/test/localization/test_lighthouse_system_aligner.py +++ b/test/localization/test_lighthouse_system_aligner.py @@ -95,6 +95,29 @@ def test_that_solution_is_de_flipped(self): # Assert self.assertPosesAlmostEqual(expected, actual[bs_id]) + def test_that_solution_is_de_flipped_with_first_bs_under_the_foor(self): + # Fixture + origin = (0.0, 0.0, 0.0) + x_axis = [(-1.0, 0.0, 0.0)] + xy_plane = [(2.0, 1.0, 0.0)] + + bs_poses = {} + + bs_id_1 = 7 + bs_poses[bs_id_1] = Pose.from_rot_vec(t_vec=(0.0, 0.0, -0.1)) + expected_1 = Pose.from_rot_vec(R_vec=(0.0, 0.0, np.pi), t_vec=(0.0, 0.0, -0.1)) + + bs_id_2 = 8 + bs_poses[bs_id_2] = Pose.from_rot_vec(t_vec=(0.0, 0.0, 1.0)) + expected_2 = Pose.from_rot_vec(R_vec=(0.0, 0.0, np.pi), t_vec=(0.0, 0.0, 1.0)) + + # Test + actual, transform = LighthouseSystemAligner.align(origin, x_axis, xy_plane, bs_poses) + + # Assert + self.assertPosesAlmostEqual(expected_1, actual[bs_id_1]) + self.assertPosesAlmostEqual(expected_2, actual[bs_id_2]) + def test_that_is_aligned_for_multiple_points_where_system_is_rotated_and_poins_are_fuzzy(self): # Fixture origin = (0.0, 0.0, 0.0) diff --git a/test/localization/test_lighthouse_types.py b/test/localization/test_lighthouse_types.py index dacc2e27b..dd63285b2 100644 --- a/test/localization/test_lighthouse_types.py +++ b/test/localization/test_lighthouse_types.py @@ -22,6 +22,7 @@ from test.localization.lighthouse_test_base import LighthouseTestBase import numpy as np +import yaml from cflib.localization.lighthouse_types import Pose @@ -96,3 +97,26 @@ def test_rotate_translate_pose_and_back(self): # Assert self.assertPosesAlmostEqual(expected, actual) + + def test_pose_equality(self): + # Fixture + pose1 = Pose.from_rot_vec(R_vec=(1.0, 2.0, 3.0), t_vec=(0.1, 0.2, 0.3)) + pose2 = Pose.from_rot_vec(R_vec=(1.0, 2.0, 3.0), t_vec=(0.1, 0.2, 0.3)) + pose3 = Pose.from_rot_vec(R_vec=(4.0, 5.0, 6.0), t_vec=(7.0, 8.0, 9.0)) + + # Test + # Assert + self.assertEqual(pose1, pose2) + self.assertNotEqual(pose1, pose3) + + def test_pose_yaml(self): + # Fixture + expected = Pose.from_rot_vec(R_vec=(1.0, 2.0, 3.0), t_vec=(0.1, 0.2, 0.3)) + + # Test + yaml_str = yaml.dump(expected) + actual = yaml.load(yaml_str, Loader=yaml.FullLoader) + + # Assert + self.assertTrue(yaml_str.startswith('!Pose')) + self.assertEqual(expected, actual)