Commit 24546b8

Merge pull request #119 from hmakelin/use-lightglue
Use LightGlue
2 parents 3e520d1 + afc6467 commit 24546b8

File tree

6 files changed: +74 -42 lines changed


docker/mavros/Dockerfile

Lines changed: 3 additions & 2 deletions
@@ -97,8 +97,9 @@ COPY docker/mavros/gisnav/entrypoint.sh /
 
 RUN chmod +x /entrypoint.sh
 
-# Download LoFTR pretrained weights
-RUN python3 -c "from kornia.feature import LoFTR; LoFTR(pretrained='outdoor')"
+# Download LightGlue pretrained weights
+RUN python3 -c \
+    "from kornia.feature import LightGlueMatcher, DISK; LightGlueMatcher('disk'); DISK.from_pretrained('depth')"
 
 # Socat for bridging serial port to PX4 container when simulating
 RUN apt-get update \
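Expanded for readability, the new RUN instruction amounts to the following (a sketch: kornia fetches both pretrained checkpoints on first instantiation, so running this at build time caches the weights inside the image and the node does not download them at startup):

    # Pre-download LightGlue and DISK pretrained weights into the image cache
    from kornia.feature import DISK, LightGlueMatcher

    LightGlueMatcher("disk")       # fetches the LightGlue checkpoint for DISK descriptors
    DISK.from_pretrained("depth")  # fetches the DISK checkpoint named "depth"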

docs/pages/glossary.rst

Lines changed: 8 additions & 2 deletions
@@ -249,8 +249,8 @@ correct alternative should be easily inferred from the context they are used in.
         navigation filter does not use EKF.
 
     Network
-        A neural network (a machine learning :term:`model`), such as SuperGlue and
-        LoFTR
+        A neural network (a machine learning :term:`model`), such as
+        :term:`LightGlue`
 
     Node
         A :term:`ROS` node.
@@ -725,6 +725,12 @@ relate to external interfaces.
     Jupyter notebook
         A web based :term:`IDE`: `jupyter.org <https://jupyter.org/>`_
 
+    LightGlue
+        A keypoint matching model: `github.com/cvg/LightGlue <https://github.com/cvg/LightGlue>`_
+
+    LoFTR
+        A keypoint matching model that GISNav used before switching over to :term:`LightGlue`: `zju3dv.github.io/loftr <https://zju3dv.github.io/loftr/>`_
+
     Make
         GNU Make, a build automation tool: `gnu.org/software/make/ <https://www.gnu.org/software/make/>`_

gisnav/gisnav/core/gis_node.py

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ def _connect_wms(url: str, version: str, timeout: int, poll_rate: float):
 
     @narrow_types
     def _bounding_box_with_padding_for_latlon(
-        self, latitude: float, longitude: float, padding: float = 100.0
+        self, latitude: float, longitude: float, padding: float = 600.0
    ):
         """Adds 100 meters of padding to coordinates on both sides"""
         meters_in_degree = 111045.0  # at 0 latitude
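The method converts the padding argument (now 600 m by default) from meters to degrees using the length of one degree of latitude at the equator, roughly 111 045 m. A minimal sketch of the computation implied by the visible signature and comment (the bounding-box shape and return value are assumptions, not the module's actual code):

    # Hypothetical sketch: pad a lat/lon point by `padding` meters on each side
    def bounding_box_with_padding_for_latlon(
        latitude: float, longitude: float, padding: float = 600.0
    ):
        meters_in_degree = 111045.0  # length of one degree of latitude at the equator
        delta = padding / meters_in_degree  # padding converted to degrees
        # (left, bottom, right, top) box around the point; ignores the
        # cos(latitude) shrinkage of longitude degrees away from the equator
        return (
            longitude - delta,
            latitude - delta,
            longitude + delta,
            latitude + delta,
        )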

gisnav/gisnav/core/pose_node.py

Lines changed: 54 additions & 32 deletions
@@ -19,7 +19,7 @@
 from builtin_interfaces.msg import Time
 from cv_bridge import CvBridge
 from geometry_msgs.msg import PoseWithCovariance, PoseWithCovarianceStamped
-from kornia.feature import LoFTR
+from kornia.feature import DISK, LightGlueMatcher, laf_from_center_scale_ori
 from rclpy.node import Node
 from rclpy.qos import QoSPresetProfiles
 from robot_localization.srv import SetPose
@@ -47,8 +47,8 @@
 # TODO: make error model and generate covariance matrix dynamically
 # Create dummy covariance matrix
 _covariance_matrix = np.zeros((6, 6))
-np.fill_diagonal(_covariance_matrix, 36)  # 3 meter SD = 9 variance
-_covariance_matrix[3, 3] = np.radians(15**2)  # angle error should be set quite small
+np.fill_diagonal(_covariance_matrix, 9)  # 3 meter SD = 9 variance
+_covariance_matrix[3, 3] = np.radians(5**2)  # angle error should be set quite small
 _covariance_matrix[4, 4] = _covariance_matrix[3, 3]
 _covariance_matrix[5, 5] = _covariance_matrix[3, 3]
 _COVARIANCE_LIST = _covariance_matrix.flatten().tolist()
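The new diagonal of 9 matches the pre-existing "3 meter SD = 9 variance" comment (variance is the square of the standard deviation). The flattened 6x6 matrix feeds the covariance field of the pose messages this node publishes; a minimal sketch of that consumption (the message wiring is an assumption, only the covariance assignment is shown):

    import numpy as np
    from geometry_msgs.msg import PoseWithCovariance

    # Rebuild the dummy covariance above: 3 m SD (variance 9) on x/y/z,
    # radians(5**2) on the roll/pitch/yaw entries
    cov = np.zeros((6, 6))
    np.fill_diagonal(cov, 9)
    cov[3, 3] = cov[4, 4] = cov[5, 5] = np.radians(5**2)

    msg = PoseWithCovariance()
    msg.covariance = cov.flatten().tolist()  # row-major 36-element list, as ROS expects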
@@ -66,7 +66,7 @@ class PoseNode(Node):
     Stricter threshold for shallow matching because mistakes accumulate in VO
     """
 
-    CONFIDENCE_THRESHOLD_DEEP_MATCH = 0.7
+    CONFIDENCE_THRESHOLD_DEEP_MATCH = 0.8
     """Confidence threshold for filtering out bad keypoint matches for
     deep matching
     """
@@ -84,8 +84,19 @@ def __init__(self, *args, **kwargs):
         self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Initialize DL model for map matching (noisy global position, no drift)
-        self._model = LoFTR(pretrained="outdoor")
-        self._model.to(self._device)
+        self._matcher = (
+            LightGlueMatcher(
+                "disk",
+                params={
+                    "filter_threshold": self.CONFIDENCE_THRESHOLD_DEEP_MATCH,
+                    "depth_confidence": -1,
+                    "width_confidence": -1,
+                },
+            )
+            .to(self._device)
+            .eval()
+        )
+        self._extractor = DISK.from_pretrained("depth").to(self._device)
 
         # Initialize ORB detector and brute force matcher for VO
         # (smooth relative position with drift)
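For context on the params dict: filter_threshold is LightGlue's match confidence cutoff (here tied to CONFIDENCE_THRESHOLD_DEEP_MATCH), while setting depth_confidence and width_confidence to -1 disables LightGlue's adaptive early-exit and keypoint-pruning optimizations so every layer processes every keypoint, trading speed for accuracy. A standalone sketch of the same initialization outside the node (device selection mirrors the code above):

    import torch
    from kornia.feature import DISK, LightGlueMatcher

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    matcher = LightGlueMatcher(
        "disk",  # LightGlue weights trained for DISK descriptors
        params={
            "filter_threshold": 0.8,  # drop matches below this confidence
            "depth_confidence": -1,   # -1 disables adaptive early exit
            "width_confidence": -1,   # -1 disables adaptive keypoint pruning
        },
    ).to(device).eval()
    extractor = DISK.from_pretrained("depth").to(device)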
@@ -193,18 +204,16 @@ def _get_pose(
         r_inv = r.T
         camera_optical_position_in_world = -r_inv @ t
 
-        # tf_.visualize_camera_position(
-        #     ref.copy(),
-        #     camera_optical_position_in_world,
-        #     f"Camera {'principal point' if shallow_inference else 'position'} "
-        #     f"in {'previous' if shallow_inference else 'world'} frame, "
-        #     f"{'shallow' if shallow_inference else 'deep'} inference",
-        # )
-
-        header = msg.query.header
+        tf_.visualize_camera_position(
+            ref.copy(),
+            camera_optical_position_in_world,
+            f"Camera {'principal point' if shallow_inference else 'position'} "
+            f"in {'previous' if shallow_inference else 'world'} frame, "
+            f"{'shallow' if shallow_inference else 'deep'} inference",
+        )
 
         pose = tf_.create_pose_msg(
-            header.stamp,
+            msg.query.header.stamp,
             cast(FrameID, "earth"),
             r_inv,
             camera_optical_position_in_world,
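The camera position computed above follows the standard pinhole relation: for a world-to-camera transform x_cam = r @ x_world + t, the camera center is the world point that maps to the camera origin, i.e. -r.T @ t. A tiny worked example:

    import numpy as np

    # Toy world-to-camera pose: identity rotation, translation [1, 2, 3]
    r = np.eye(3)
    t = np.array([[1.0], [2.0], [3.0]])

    # Camera center in world coordinates: solve r @ x + t = 0 for x
    camera_position_in_world = -r.T @ t  # -> [[-1.], [-2.], [-3.]]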
@@ -348,24 +357,37 @@ def _process(
         :return: Tuple of matched query image keypoints, and matched reference image
             keypoints
         """
-        if not shallow_inference:  #
-            if torch.cuda.is_available():
-                qry_tensor = torch.Tensor(qry[None, None]).cuda() / 255.0
-                ref_tensor = torch.Tensor(ref[None, None]).cuda() / 255.0
-            else:
-                self.get_logger().warning("CUDA not available - using CPU.")
-                qry_tensor = torch.Tensor(qry[None, None]) / 255.0
-                ref_tensor = torch.Tensor(ref[None, None]) / 255.0
-
-            with torch.no_grad():
-                results = self._model({"image0": qry_tensor, "image1": ref_tensor})
-
-            conf = results["confidence"].cpu().numpy()
-            good = conf > self.CONFIDENCE_THRESHOLD_DEEP_MATCH
-            mkp_qry = results["keypoints0"].cpu().numpy()[good, :]
-            mkp_ref = results["keypoints1"].cpu().numpy()[good, :]
+        if not shallow_inference:
+            qry_tensor = torch.Tensor(qry[None, None]).to(self._device) / 255.0
+            ref_tensor = torch.Tensor(ref[None, None]).to(self._device) / 255.0
+            qry_tensor = qry_tensor.expand(-1, 3, -1, -1)
+            ref_tensor = ref_tensor.expand(-1, 3, -1, -1)
+
+            with torch.inference_mode():
+                input = torch.cat([qry_tensor, ref_tensor], dim=0)
+                # limit number of features to run faster, None means no limit i.e.
+                # slow but accurate
+                max_keypoints = 1024  # 4096 # None
+                feat_qry, feat_ref = self._extractor(
+                    input, max_keypoints, pad_if_not_divisible=True
+                )
+                kp_qry, desc_qry = feat_qry.keypoints, feat_qry.descriptors
+                kp_ref, desc_ref = feat_ref.keypoints, feat_ref.descriptors
+                lafs_qry = laf_from_center_scale_ori(
+                    kp_qry[None], torch.ones(1, len(kp_qry), 1, 1, device=self._device)
+                )
+                lafs_ref = laf_from_center_scale_ori(
+                    kp_ref[None], torch.ones(1, len(kp_ref), 1, 1, device=self._device)
+                )
+                dists, match_indices = self._matcher(
+                    desc_qry, desc_ref, lafs_qry, lafs_ref
+                )
+
+            mkp_qry = kp_qry[match_indices[:, 0]].cpu().numpy()
+            mkp_ref = kp_ref[match_indices[:, 1]].cpu().numpy()
 
             return mkp_qry, mkp_ref
+
         else:
             # find the keypoints and descriptors with ORB
             kp_qry, desc_qry = self._orb.detectAndCompute(qry, None)
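Putting the new extractor and matcher together outside ROS, a minimal sketch of the deep-matching path (file names are placeholders; it assumes both images share identical dimensions, since they are batched through DISK in a single call just like the torch.cat trick above):

    import cv2
    import torch
    from kornia.feature import DISK, LightGlueMatcher, laf_from_center_scale_ori

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    extractor = DISK.from_pretrained("depth").to(device)
    matcher = LightGlueMatcher("disk", params={"filter_threshold": 0.8}).to(device).eval()

    # Load two same-sized grayscale images and expand to 3-channel float tensors
    qry = cv2.imread("query.png", cv2.IMREAD_GRAYSCALE)
    ref = cv2.imread("reference.png", cv2.IMREAD_GRAYSCALE)
    qry_t = (torch.Tensor(qry[None, None]).to(device) / 255.0).expand(-1, 3, -1, -1)
    ref_t = (torch.Tensor(ref[None, None]).to(device) / 255.0).expand(-1, 3, -1, -1)

    with torch.inference_mode():
        # Extract up to 1024 DISK keypoints and descriptors per image in one batch
        feat_qry, feat_ref = extractor(
            torch.cat([qry_t, ref_t], dim=0), 1024, pad_if_not_divisible=True
        )
        # LightGlueMatcher wants local affine frames (LAFs); build trivial ones
        # (unit scale) around the DISK keypoints, as the node does above
        lafs_qry = laf_from_center_scale_ori(
            feat_qry.keypoints[None],
            torch.ones(1, len(feat_qry.keypoints), 1, 1, device=device),
        )
        lafs_ref = laf_from_center_scale_ori(
            feat_ref.keypoints[None],
            torch.ones(1, len(feat_ref.keypoints), 1, 1, device=device),
        )
        dists, idxs = matcher(
            feat_qry.descriptors, feat_ref.descriptors, lafs_qry, lafs_ref
        )

    mkp_qry = feat_qry.keypoints[idxs[:, 0]].cpu().numpy()  # matched query keypoints
    mkp_ref = feat_ref.keypoints[idxs[:, 1]].cpu().numpy()  # matched reference keypoints

Note the design difference from LoFTR: LoFTR is detector-free and returns matches with per-match confidences directly, so the old code thresholded afterwards; with DISK + LightGlue the confidence filtering happens inside the matcher via filter_threshold, which is why no threshold appears in the new match loop.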

gisnav/launch/params/ekf_node.yaml

Lines changed: 7 additions & 4 deletions
@@ -13,14 +13,17 @@ robot_localization:
     # need predict_to_current_time if required output Hz is higher than input
     #predict_to_current_time: true
     sensor_timeout: 10.0
-    pose0_queue_size: 20
-    pose0_rejection_threshold: 3.0
+    #pose0_queue_size: 20
+    pose0_rejection_threshold: 2.0
+
+    dynamic_process_noise_covariance: True
+    transform_timeout: 0.200
 
     use_sim_time: false
     two_d_mode: false
 
-    smooth_lagged_data: true
-    history_length: 10.0
+    #smooth_lagged_data: true
+    #history_length: 10.0
     #print_diagnostics: true
 
     # Fuse absolute pose estimated from map rasters
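For context, robot_localization's pose0_rejection_threshold is a Mahalanobis-distance gate on incoming pose measurements, so lowering it from 3.0 to 2.0 discards outlier poses more aggressively. A simplified illustration of such a gate (not the package's code):

    import numpy as np

    def passes_rejection_gate(innovation, covariance, threshold=2.0):
        # Mahalanobis distance of the measurement innovation under its covariance;
        # the measurement is rejected when the distance exceeds the threshold
        d = float(np.sqrt(innovation.T @ np.linalg.inv(covariance) @ innovation))
        return d <= threshold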

gisnav/setup.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def parse_package_data(cls, package_file: str) -> PackageData:
         # "shapely>=1.8.2",
         "OWSLib>=0.25.0",
         "torch>=2.1.0",
-        "kornia==0.6.10",
+        "kornia==0.7.2",  # 0.7.2 for LightGlue and DISK
         "transforms3d",  # tf_transformations needs this
     ],
     tests_require=["pytest"],
