diff --git a/export.py b/export.py
index 8e6342d..0a97818 100644
--- a/export.py
+++ b/export.py
@@ -3,10 +3,14 @@
 import torch
 
+from lightglue_onnx.aliked.aliked import ALIKED
 from lightglue_onnx import DISK, LightGlue, LightGlueEnd2End, SuperPoint
 from lightglue_onnx.end2end import normalize_keypoints
 from lightglue_onnx.utils import load_image, rgb_to_grayscale
 
+from lightglue_onnx.aliked import deform_conv2d_onnx_exporter
+
+deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()
 
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -22,10 +26,17 @@ def parse_args() -> argparse.Namespace:
         "--extractor_type",
         type=str,
         default="superpoint",
-        choices=["superpoint", "disk"],
+        choices=["superpoint", "disk", "aliked"],
         required=False,
-        help="Type of feature extractor. Supported extractors are 'superpoint' and 'disk'. Defaults to 'superpoint'.",
+        help="Type of feature extractor. Supported extractors are 'superpoint', 'disk', and 'aliked'. Defaults to 'superpoint'.",
     )
+    parser.add_argument(
+        "--aliked_model",
+        type=str,
+        default=None,
+        required=False,
+        help="ALIKED model variant: aliked-t16, aliked-n16, aliked-n16rot, or aliked-n32. Required when --extractor_type is 'aliked'.",
+    )
     parser.add_argument(
         "--extractor_path",
         type=str,
@@ -64,6 +75,7 @@ def parse_args() -> argparse.Namespace:
 def export_onnx(
     img_size=512,
     extractor_type="superpoint",
+    aliked_model="",
     extractor_path=None,
     lightglue_path=None,
     img0_path="assets/sacre_coeur1.jpg",
@@ -76,6 +88,18 @@ def export_onnx(
     if isinstance(img_size, List) and len(img_size) == 1:
         img_size = img_size[0]
 
+    # Handle aliked desc dim
+    aliked_desc_dim: dict[str, int] = {
+        "aliked-t16": 64,
+        "aliked-n16": 128,
+        "aliked-n16rot": 128,
+        "aliked-n32": 128,
+    }
+    if extractor_type == "aliked" and aliked_model not in aliked_desc_dim:
+        raise ValueError(
+            "The specified ALIKED model was not found. Choose one of: "
+            "aliked-t16, aliked-n16, aliked-n16rot, or aliked-n32.")
+
     if extractor_path is not None and end2end:
         raise ValueError(
             "Extractor will be combined with LightGlue when exporting end-to-end model."
         )
@@ -108,6 +132,15 @@ def export_onnx(
     elif extractor_type == "disk":
         extractor = DISK(max_num_keypoints=max_num_keypoints).eval()
         lightglue = LightGlue(extractor_type).eval()
+    elif extractor_type == "aliked":
+        # image0 = image0.cuda()
+        # image1 = image1.cuda()
+        extractor = ALIKED(
+            model_name=aliked_model,
+            device="cpu",
+            top_k=max_num_keypoints
+        )
+        lightglue = LightGlue(aliked_model).eval()
     else:
         raise NotImplementedError(
             f"LightGlue has not been trained on {extractor_type} features."
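For orientation, the new extractor type can be driven through the flags added above (python export.py --extractor_type aliked --aliked_model aliked-n16 ...) or directly from Python. A minimal sketch against the export_onnx signature shown in this hunk; the output paths are placeholders, and max_num_keypoints is assumed to be a keyword of export_onnx because it is referenced inside the function body but its default lies outside this hunk:

from export import export_onnx

export_onnx(
    img_size=512,
    extractor_type="aliked",
    aliked_model="aliked-n16",  # one of aliked-t16 / aliked-n16 / aliked-n16rot / aliked-n32
    extractor_path="weights/aliked-n16.onnx",        # placeholder output path
    lightglue_path="weights/aliked_lightglue.onnx",  # placeholder output path
    max_num_keypoints=1024,  # assumed keyword; see note above
)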
diff --git a/lightglue_onnx/aliked/__init__.py b/lightglue_onnx/aliked/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lightglue_onnx/aliked/aliked.py b/lightglue_onnx/aliked/aliked.py new file mode 100644 index 0000000..fef0f4c --- /dev/null +++ b/lightglue_onnx/aliked/aliked.py @@ -0,0 +1,235 @@ +import os.path as osp +import time + +import numpy as np +import torch + +from torch import nn +from torchvision.models import resnet +from torchvision.transforms import ToTensor + +from .soft_detect import DKD +from .padder import InputPadder +from .blocks import * + + +ALIKED_CFGS = { + "aliked-t16": { + "c1": 8, + "c2": 16, + "c3": 32, + "c4": 64, + "dim": 64, + "K": 3, + "M": 16, + }, + "aliked-n16": { + "c1": 16, + "c2": 32, + "c3": 64, + "c4": 128, + "dim": 128, + "K": 3, + "M": 16, + }, + "aliked-n16rot": { + "c1": 16, + "c2": 32, + "c3": 64, + "c4": 128, + "dim": 128, + "K": 3, + "M": 16, + }, + "aliked-n32": { + "c1": 16, + "c2": 32, + "c3": 64, + "c4": 128, + "dim": 128, + "K": 3, + "M": 32, + }, +} + + +class ALIKED(nn.Module): + + def __init__( + self, + model_name: str = "aliked-n32", + device: str = "cuda", + top_k: int = -1, # -1 for threshold based mode, >0 for top K mode. + scores_th: float = 0.2, + n_limit: int = 5000, # Maximum number of keypoints to be detected + load_pretrained: bool = True, + ): + super().__init__() + + # get configurations + c1, c2, c3, c4, dim, K, M = [ + v for _, v in ALIKED_CFGS[model_name].items() + ] + conv_types = ["conv", "conv", "dcn", "dcn"] + conv2D = False + mask = False + self.device = device + + # build model + self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) + self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) + self.norm = nn.BatchNorm2d + self.gate = nn.SELU(inplace=True) + self.block1 = ConvBlock( + 3, c1, self.gate, self.norm, conv_type=conv_types[0] + ) + self.block2 = ResBlock( + c1, + c2, + 1, + nn.Conv2d(c1, c2, 1), + gate=self.gate, + norm_layer=self.norm, + conv_type=conv_types[1], + ) + self.block3 = ResBlock( + c2, + c3, + 1, + nn.Conv2d(c2, c3, 1), + gate=self.gate, + norm_layer=self.norm, + conv_type=conv_types[2], + mask=mask, + device=self.device, + ) + self.block4 = ResBlock( + c3, + c4, + 1, + nn.Conv2d(c3, c4, 1), + gate=self.gate, + norm_layer=self.norm, + conv_type=conv_types[3], + mask=mask, + device=self.device, + ) + self.conv1 = resnet.conv1x1(c1, dim // 4) + self.conv2 = resnet.conv1x1(c2, dim // 4) + self.conv3 = resnet.conv1x1(c3, dim // 4) + self.conv4 = resnet.conv1x1(dim, dim // 4) + self.upsample2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.upsample4 = nn.Upsample( + scale_factor=4, mode="bilinear", align_corners=True + ) + self.upsample8 = nn.Upsample( + scale_factor=8, mode="bilinear", align_corners=True + ) + self.upsample32 = nn.Upsample( + scale_factor=32, mode="bilinear", align_corners=True + ) + self.score_head = nn.Sequential( + resnet.conv1x1(dim, 8), + self.gate, + resnet.conv3x3(8, 4), + self.gate, + resnet.conv3x3(4, 4), + self.gate, + resnet.conv3x3(4, 1), + ) + self.desc_head = SDDH( + dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask, device=self.device + ) + self.dkd = DKD( + radius=2, top_k=top_k, scores_th=scores_th, n_limit=n_limit + ) + + # load pretrained + if load_pretrained: + url = f"https://raw.githubusercontent.com/ajuric/aliked-tensorrt/main/models/{model_name}.pth" + print(f"loading {url}") + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu") + self.load_state_dict(state_dict, strict=True) + 
+        self.to(device)
+        self.eval()
+
+    def extract_dense_map(self, image):
+        # Pad the image so its spatial dimensions are divisible by div_by (2**5 = 32)
+        div_by = 2**5
+        padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
+        image = padder.pad(image)
+
+        # ================================== feature encoder
+        x1 = self.block1(image)  # B x c1 x H x W
+        x2 = self.pool2(x1)
+        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
+        x3 = self.pool4(x2)
+        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
+        x4 = self.pool4(x3)
+        x4 = self.block4(x4)  # B x dim x H/32 x W/32
+        # ================================== feature aggregation
+        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
+        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
+        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
+        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32
+        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
+        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
+        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
+        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
+        # ================================== score head
+        score_map = torch.sigmoid(self.score_head(x1234))
+        feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
+
+        # Unpad the feature and score maps back to the input resolution
+        feature_map = padder.unpad(feature_map)
+        score_map = padder.unpad(score_map)
+
+        return feature_map, score_map
+
+    def forward(self, image):
+        if torch.cuda.is_available(): torch.cuda.synchronize()  # no-op when exporting on CPU
+        # t0 = time.time()
+        feature_map, score_map = self.extract_dense_map(image)
+        keypoints, kptscores, scoredispersitys = self.dkd(score_map)
+        descriptors, offsets = self.desc_head(feature_map, keypoints)
+        if torch.cuda.is_available(): torch.cuda.synchronize()
+        # t1 = time.time()
+
+        # return {
+        #     "keypoints": keypoints,  # B N 2
+        #     "descriptors": descriptors,  # B N D
+        #     "scores": kptscores,  # B N
+        #     # 'score_dispersity': scoredispersitys,
+        #     # 'score_map': score_map,  # Bx1xHxW
+        #     # 'time': t1-t0,
+        # }
+        return keypoints, kptscores, descriptors
+
+    def warmup(self, image: np.ndarray, num_iterations: int = 3) -> None:
+        print("Starting warm-up ...")
+        for _ in range(num_iterations):
+            self.run(image)
+        print("Warm-up done!")
+
+    def run(self, img_rgb):
+
+        img_tensor = ToTensor()(img_rgb)
+        # img_tensor = img_tensor.to(self.device).unsqueeze_(0).half()
+        img_tensor = img_tensor.to(self.device).unsqueeze_(0)
+
+        with torch.no_grad():
+            keypoints, scores, descriptors = self.forward(img_tensor)  # forward() returns (keypoints, scores, descriptors)
+
+        keypoints = keypoints[0]
+        _, _, h, w = img_tensor.shape
+        wh = torch.tensor([w - 1, h - 1], device=keypoints.device)
+        keypoints = wh * (keypoints + 1) / 2
+
+        return {
+            "keypoints": keypoints.cpu().numpy(),  # N 2
+            "scores": scores[0].cpu().numpy(),  # N
+            "descriptors": descriptors[0].cpu().numpy(),  # N D
+        }
diff --git a/lightglue_onnx/aliked/blocks.py b/lightglue_onnx/aliked/blocks.py
new file mode 100644
index 0000000..d5bdb08
--- /dev/null
+++ b/lightglue_onnx/aliked/blocks.py
@@ -0,0 +1,400 @@
+from typing import Optional, Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.modules.utils import _pair
+import torchvision
+# from custom_ops import get_patches
+
+
+def get_patches(
+    input_map: torch.Tensor, locations: torch.Tensor, patch_size: int = 3, device: str = ""
+):
+    # TODO: check shapes of input params.
+
+    # Convert patch center locations to integer to support indexing.
+    locations_long = locations.long()
+
+    # Create grid for patch, for indexing.
+ window_low = 0 + window_high = patch_size + y_grid, x_grid = torch.meshgrid( + torch.arange(window_low, window_high), + torch.arange(window_low, window_high), + ) + + x_coords = locations_long[:, 0] - 1 + y_coords = locations_long[:, 1] - 1 + x_indices = x_coords.unsqueeze(-1).unsqueeze(-1) + x_grid.to(device) + y_indices = y_coords.unsqueeze(-1).unsqueeze(-1) + y_grid.to(device) + + input_shape = input_map.shape + x_indices = torch.clip(x_indices, 0, input_shape[2] - 1) + y_indices = torch.clip(y_indices, 0, input_shape[1] - 1) + patches = input_map[:, y_indices, x_indices] + patches = patches.permute(1, 0, 2, 3) + return patches + + +class DeformableConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + mask=False, + device="cpu" + ): + super(DeformableConv2d, self).__init__() + + self.padding = padding + self.mask = mask + self.device = device + self.channel_num = ( + 3 * kernel_size * kernel_size + if mask + else 2 * kernel_size * kernel_size + ) + self.offset_conv = nn.Conv2d( + in_channels, + self.channel_num, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=True, + ) + + self.regular_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=bias, + ) + + def forward(self, x): + h, w = x.shape[2:] + max_offset = max(h, w) / 4.0 + max_offset = torch.tensor(max_offset, device=self.device) + + out = self.offset_conv(x) + if self.mask: + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + else: + offset = out + mask = None + offset = offset.clamp(-max_offset, max_offset) + x = torchvision.ops.deform_conv2d( + input=x, + offset=offset, + weight=self.regular_conv.weight, + bias=self.regular_conv.bias, + padding=self.padding, + mask=mask, + ) + return x + + +def get_conv( + inplanes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_type="conv", + mask=False, + device="", +): + if conv_type == "conv": + conv = nn.Conv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + elif conv_type == "dcn": + conv = DeformableConv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=_pair(padding), + bias=bias, + mask=mask, + device=device, + ) + else: + raise TypeError + return conv + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + ): + super().__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = get_conv( + in_channels, + out_channels, + kernel_size=3, + conv_type=conv_type, + mask=mask, + ) + self.bn1 = norm_layer(out_channels) + self.conv2 = get_conv( + out_channels, + out_channels, + kernel_size=3, + conv_type=conv_type, + mask=mask, + ) + self.bn2 = norm_layer(out_channels) + + def forward(self, x): + x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W + x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W + return x + + +# modified based on torchvision\models\resnet.py#27->BasicBlock +class ResBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + 
downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + device: str = "cpu", + ) -> None: + super(ResBlock, self).__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + "ResBlock only supports groups=1 and base_width=64" + ) + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in ResBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = get_conv( + inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask, device=device + ) + self.bn1 = norm_layer(planes) + self.conv2 = get_conv( + planes, planes, kernel_size=3, conv_type=conv_type, mask=mask, device=device + ) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.gate(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.gate(out) + + return out + + +class SDDH(nn.Module): + def __init__( + self, + dims: int, + kernel_size: int = 3, + n_pos: int = 8, + gate=nn.ReLU(), + conv2D=False, + mask=False, + device="cpu", + ): + super(SDDH, self).__init__() + self.kernel_size = kernel_size + self.n_pos = n_pos + self.conv2D = conv2D + self.mask = mask + self.device = device + # self.get_patches_func = get_patches.apply + + # estimate offsets + self.channel_num = 3 * n_pos if mask else 2 * n_pos + self.offset_conv = nn.Sequential( + nn.Conv2d( + dims, + self.channel_num, + kernel_size=kernel_size, + stride=1, + padding=0, + bias=True, + ), + gate, + nn.Conv2d( + self.channel_num, + self.channel_num, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + # sampled feature conv + self.sf_conv = nn.Conv2d( + dims, dims, kernel_size=1, stride=1, padding=0, bias=False + ) + + # convM + if not conv2D: + # deformable desc weights + agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims)) + self.register_parameter("agg_weights", agg_weights) + else: + self.convM = nn.Conv2d( + dims * n_pos, + dims, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + + def forward(self, x, keypoints): + # x: [B,C,H,W] + # keypoints: list, [[N_kpts,2], ...] (w,h) + b, c, h, w = x.shape + keypoints = list(torch.unbind(keypoints, dim=0)) + wh = torch.tensor([[w - 1, h - 1]], device=x.device) + max_offset = max(h, w) / 4.0 + max_offset = torch.tensor(max_offset, device=self.device) + + offsets = [] + descriptors = [] + # get offsets for each keypoint + for ib in range(b): + xi, kptsi = x[ib], keypoints[ib] + kptsi_wh = (kptsi / 2 + 0.5) * wh + N_kpts = len(kptsi) + + # xi.shape: (128, 376, 1241) - (D, H, W) + # kpts_wh.shape: (1441, 2) - (N, 2), x,y -> W, H + # kernel_size: 3 - int + # patch.shape: (N, D, kernel_size, kernel_size) + + if self.kernel_size > 1: + # patch = self.get_patches_func( + # xi, kptsi_wh.long(), self.kernel_size + # ) # [N_kpts, C, K, K] + # sad treba usporediti nakom permute!!! 
+ patch = get_patches( + xi, kptsi_wh.long(), self.kernel_size, device=self.device + ) # [N_kpts, C, K, K] + else: + kptsi_wh_long = kptsi_wh.long() + patch = ( + xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]] + .permute(1, 0) + .reshape(N_kpts, c, 1, 1) + ) + + offset = self.offset_conv(patch).clamp( + -max_offset, max_offset + ) # [N_kpts, 2*n_pos, 1, 1] + if self.mask: + offset = ( + offset[:, :, 0, 0] + .view(N_kpts, 3, self.n_pos) + .permute(0, 2, 1) + ) # [N_kpts, n_pos, 3] + offset = offset[:, :, :-1] # [N_kpts, n_pos, 2] + mask_weight = torch.sigmoid( + offset[:, :, -1] + ) # [N_kpts, n_pos] + else: + offset = ( + offset[:, :, 0, 0] + .view(N_kpts, 2, self.n_pos) + .permute(0, 2, 1) + ) # [N_kpts, n_pos, 2] + offsets.append(offset) # for visualization + + # get sample positions + pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2] + pos = 2.0 * pos / wh[None] - 1 + pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2) + + # sample features + features = F.grid_sample( + xi.unsqueeze(0), pos, mode="bilinear", align_corners=True + ) # [1,C,(N_kpts*n_pos),1] + features = features.reshape(c, N_kpts, self.n_pos, 1).permute( + 1, 0, 2, 3 + ) # [N_kpts, C, n_pos, 1] + if self.mask: + features = torch.einsum("ncpo,np->ncpo", features, mask_weight) + + features = torch.selu_(self.sf_conv(features)).squeeze( + -1 + ) # [N_kpts, C, n_pos] + # convM + if not self.conv2D: + descs = torch.einsum( + "ncp,pcd->nd", features, self.agg_weights + ) # [N_kpts, C] + else: + features = features.reshape(N_kpts, -1)[ + :, :, None, None + ] # [N_kpts, C*n_pos, 1, 1] + descs = self.convM(features).squeeze() # [N_kpts, C] + + # normalize + descs = F.normalize(descs, p=2.0, dim=1) + descriptors.append(descs) + + descriptors = torch.stack(descriptors, dim=0) + offsets = torch.stack(offsets, dim=0) + + return descriptors, offsets diff --git a/lightglue_onnx/aliked/deform_conv2d_onnx_exporter.py b/lightglue_onnx/aliked/deform_conv2d_onnx_exporter.py new file mode 100644 index 0000000..5164355 --- /dev/null +++ b/lightglue_onnx/aliked/deform_conv2d_onnx_exporter.py @@ -0,0 +1,710 @@ +"""This module adds ONNX conversion of `deform_conv2d`. + +This module implements Deformable Convolution v2, +described in a paper, `Deformable ConvNets v2: More Deformable, Better Results +`, using ONNX operators. +The implementation is straightforward, but may not be very efficient. + +This exporter requires opset version 12 to support the following operators: + - Clip: + It can accept tensor(int64) from version 12. + - GatherND: + It can support batch_dims from version 12. 
+""" + +import torch +from torch.onnx import register_custom_op_symbolic +from torch.onnx import symbolic_helper as sym_help +try: + from torch.onnx._type_utils import JitScalarType +except ImportError: + JitScalarType = None + +__all__ = ["register_deform_conv2d_onnx_op"] + +onnx_opset_version = 12 + + +def add(g, lhs, rhs): + return g.op("Add", lhs, rhs) + + +def sub(g, lhs, rhs): + return g.op("Sub", lhs, rhs) + + +def mul(g, lhs, rhs): + return g.op("Mul", lhs, rhs) + + +def reshape(g, x, shape): + if isinstance(shape, list): + shape = tensor(g, shape, dtype=torch.int64) + return g.op("Reshape", x, shape) + + +def slice(g, x, axes, starts, ends, *, steps=None): + axes = tensor(g, axes, dtype=torch.int64) + starts = tensor(g, starts, dtype=torch.int64) + ends = tensor(g, ends, dtype=torch.int64) + if steps is not None: + steps = tensor(g, steps, dtype=torch.int64) + return g.op("Slice", x, starts, ends, axes, steps) + else: + return g.op("Slice", x, starts, ends, axes) + + +def unsqueeze(g, input, dims): + return sym_help._unsqueeze_helper(g, input, axes_i=dims) + + +def get_tensor_dim_size(tensor, dim): + tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim) + if tensor_dim_size == None and (dim == 2 or dim == 3): + import typing + from torch import _C + + x_type = typing.cast(_C.TensorType, tensor.type()) + x_strides = x_type.strides() + + tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2] + elif tensor_dim_size == None and (dim == 0): + import typing + from torch import _C + + x_type = typing.cast(_C.TensorType, tensor.type()) + x_strides = x_type.strides() + tensor_dim_size = x_strides[3] + return tensor_dim_size + + +def tensor(g, value, dtype): + return g.op("Constant", value_t=torch.tensor(value, dtype=dtype)) + + +def calculate_p_0(dcn_params): + """ + Calculate p_0 value in equation (1) in the paper. + + Args: + dcn_params: parameters for deform_conv2d. + + Returns: + torch.Tensor[1, 1, kernel_area_size, 2, out_h, out_w] + """ + h = dcn_params["out_h"] + w = dcn_params["out_w"] + stride_h = dcn_params["stride_h"] + stride_w = dcn_params["stride_w"] + K = dcn_params["kernel_area_size"] + additional_pad_h = dcn_params["additional_pad_h"] + additional_pad_w = dcn_params["additional_pad_w"] + + p_0_y, p_0_x = torch.meshgrid(torch.arange(0, h * stride_h, stride_h), + torch.arange(0, w * stride_w, stride_w)) + p_0_y = p_0_y.view(1, 1, 1, 1, h, w).repeat(1, 1, K, 1, 1, 1) + p_0_y += additional_pad_h + p_0_x = p_0_x.view(1, 1, 1, 1, h, w).repeat(1, 1, K, 1, 1, 1) + p_0_x += additional_pad_w + return torch.cat([p_0_y, p_0_x], dim=3) + + +def calculate_p_k(dcn_params): + """ + Calculate p_k value in equation (1) in the paper. + + Args: + dcn_params: parameters for deform_conv2d. + + Returns: + torch.Tensor[1, 1, kernel_area_size, 2, 1, 1] + """ + kernel_h = dcn_params["kernel_h"] + kernel_w = dcn_params["kernel_w"] + dilation_h = dcn_params["dilation_h"] + dilation_w = dcn_params["dilation_w"] + K = dcn_params["kernel_area_size"] + + p_k_y, p_k_x = torch.meshgrid( + torch.arange(0, kernel_h * dilation_h, step=dilation_h), + torch.arange(0, kernel_w * dilation_w, step=dilation_w), + ) + p_k_y = p_k_y.reshape(1, 1, K, 1, 1, 1) + p_k_x = p_k_x.reshape(1, 1, K, 1, 1, 1) + return torch.cat([p_k_y, p_k_x], dim=3) + + +def calculate_p(g, dcn_params, offset): + """ + Calculate p_0 + p_k + Delta(p_k) in equation (1) in the paper. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + offset: Delta(p_k) in the paper. 
+ The shape is (b, group, K, 2, out_h, out_w). + + Returns: + The shape is (b, group, K, 2, out_h, out_w). + """ + b = dcn_params["batch"] + K = dcn_params["kernel_area_size"] + h = dcn_params["out_h"] + w = dcn_params["out_w"] + group = dcn_params["n_offset_grps"] + offset_dtype = dcn_params["offset_dtype_pytorch"] + + offset = reshape(g, offset, [b, group, K, 2, h, w]) + + p_0 = calculate_p_0(dcn_params) + p_k = calculate_p_k(dcn_params) + p = p_0 + p_k + p = add(g, tensor(g, p.tolist(), dtype=offset_dtype), offset) + # => p.shape is (b, group, K, 2, h, w) + return p + + +def calculate_p_floor(g, dcn_params, p): + """ + Calculate floor of p. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + p: Coords for sampling points of DCN. + The shape is (b, group, K, 2, out_h, out_w). + + Returns: + The shape is (b, group, K, 2, out_h, out_w). + Note that the data type is not integer but float. + """ + p_floor = g.op("Floor", p) + return p_floor + + +def calculate_p_tlbr(g, dcn_params, p_floor): + """ + Calculate floor and ceil of p. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + p_floor: Floored coords for sampling points of DCN. + The shape is (b, group, K, 2, out_h, out_w). + + Returns: + A dict, {"t": p_t, "l", p_l, "b": p_b, "r": p_r}, which contains + "t"op, "l"eft, "b"ottom, and "r"ight coordinates around p. + The shape of p_t, ..., p_r is (b, group, K, 1, out_h, out_w). + """ + h = dcn_params["in_h"] + w = dcn_params["in_w"] + index_dtype_onnx = dcn_params["index_dtype_onnx"] + index_dtype_pytorch = dcn_params["index_dtype_pytorch"] + + p_floor = g.op("Cast", p_floor, to_i=index_dtype_onnx) + one = tensor(g, 1, dtype=index_dtype_pytorch) + + p_t = slice(g, p_floor, [3], [0], [1]) + p_l = slice(g, p_floor, [3], [1], [2]) + p_b = add(g, p_t, one) + p_r = add(g, p_l, one) + + # Clip out-of-bounds coords. + # Clipped coords point to padding area, which is filled with 0. + p_t = g.op("Clip", p_t, tensor(g, 0, dtype=index_dtype_pytorch), + tensor(g, h - 1, dtype=index_dtype_pytorch)) + p_l = g.op("Clip", p_l, tensor(g, 0, dtype=index_dtype_pytorch), + tensor(g, w - 1, dtype=index_dtype_pytorch)) + p_b = g.op("Clip", p_b, tensor(g, 0, dtype=index_dtype_pytorch), + tensor(g, h - 1, dtype=index_dtype_pytorch)) + p_r = g.op("Clip", p_r, tensor(g, 0, dtype=index_dtype_pytorch), + tensor(g, w - 1, dtype=index_dtype_pytorch)) + return { + "t": p_t, + "l": p_l, + "b": p_b, + "r": p_r, + } + + +def calculate_weight(g, dcn_params, p, p_floor): + """ + Calculate weight value for bilinear interpolation. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + p: Coords for sampling points. + The shape is (b, group, K, 2, out_h, out_w). + p_floor: Floored coords for sampling points. + The shape is (b, group, K, 2, out_h, out_w). + + Returns: + A dict, {"tl": weight_tl, "br": weight_br, ..., "tr": weight_tr}, + which contains weights for "t"op-"l"eft, "b"ottom-"r"ight, .... + The shape of weight_tl is (b, group, 1, K, out_h, out_w). 
+ """ + b = dcn_params["batch"] + group = dcn_params["n_offset_grps"] + h = dcn_params["out_h"] + w = dcn_params["out_w"] + K = dcn_params["kernel_area_size"] + offset_dtype = dcn_params["offset_dtype_pytorch"] + + one = tensor(g, 1, dtype=offset_dtype) + + diff = sub(g, p, p_floor) + diff_y = slice(g, diff, [3], [0], [1]) + diff_x = slice(g, diff, [3], [1], [2]) + diff_y_inv = sub(g, one, diff_y) + diff_x_inv = sub(g, one, diff_x) + + # bilinear kernel (b, group, K, 1, h, w) + # (1 - (p_x - p_l)) * (1 - (p_y - p_t)) + weight_tl = mul(g, diff_x_inv, diff_y_inv) + # (p_x - p_l) * (p_y - p_t) + weight_br = mul(g, diff_x, diff_y) + # (1 - (p_x - p_l)) * (p_y - p_t) + weight_bl = mul(g, diff_x_inv, diff_y) + # (p_x - p_l) * (1 - (p_y - p_t)) + weight_tr = mul(g, diff_x, diff_y_inv) + + weights = { + "tl": weight_tl, + "br": weight_br, + "bl": weight_bl, + "tr": weight_tr, + } + weights = { + key: reshape(g, weight, [b, group, 1, K, h, w]) + for key, weight in weights.items() + } + return weights + + +def reshape_input_for_gather_elements(g, dcn_params, input): + """ + Reshape input for gather_elements function. + + Even if no padding is specified, 1 padding is always added + to ensure that out-of-bounds index can be handled correctly. + + This function also transpose input tensor, so that "GatherND" + can easily gather all data in a channel. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + input: input tensor. + The shape is (b, in_ch, in_h, in_w) + + Returns: + The shape is (b, group, ch_per_group, in_h, in_w). + """ + b = dcn_params["batch"] + group = dcn_params["n_offset_grps"] + ch = dcn_params["in_ch_per_group"] + in_h = dcn_params["in_h"] + in_w = dcn_params["in_w"] + pad_h = dcn_params["padding_h"] + pad_w = dcn_params["padding_w"] + additional_pad_h = dcn_params["additional_pad_h"] + additional_pad_w = dcn_params["additional_pad_w"] + + pad_size = [ + 0, + 0, + (pad_h + additional_pad_h), + (pad_w + additional_pad_w), + 0, + 0, + (pad_h + additional_pad_h), + (pad_w + additional_pad_w), + ] + pad = tensor(g, pad_size, dtype=torch.int64) + input = g.op("Pad", input, pad, mode_s="constant") + input = reshape(g, input, [b, group, ch, in_h, in_w]) + return input + + +def gather_elements(g, dcn_params, input, p_y, p_x): + """ + Gather elements specified by p_y and p_x using GatherElements operator. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + input: input tensor. + The shape is (b, group, ch_per_group, in_h, in_w). + p_y: y coordinates of sampling points. + The shape is (b, group, K, 1, out_h, out_w). + p_x: x coordinates of sampling points. + The shape is (b, group, K, 1, out_h, out_w). + + Returns: + The shape is (b, group, ch_per_group, K, out_h, out_w). 
+ """ + b = dcn_params["batch"] + group = dcn_params["n_offset_grps"] + ch = dcn_params["in_ch_per_group"] + in_h = dcn_params["in_h"] + in_w = dcn_params["in_w"] + out_h = dcn_params["out_h"] + out_w = dcn_params["out_w"] + K = dcn_params["kernel_area_size"] + index_dtype_pytorch = dcn_params["index_dtype_pytorch"] + + p_y = reshape(g, p_y, [b, group, 1, K * out_h * out_w]) + p_x = reshape(g, p_x, [b, group, 1, K * out_h * out_w]) + p_y = g.op("Mul", p_y, tensor(g, in_w, dtype=index_dtype_pytorch)) + index = g.op("Add", p_y, p_x) + shape = [b, group, ch, K * out_h * out_w] + index = g.op("Expand", index, tensor(g, shape, dtype=torch.int64)) + + input = reshape(g, input, [b, group, ch, in_h * in_w]) + + v = g.op("GatherElements", input, index, axis_i=3) + # => v.shape is (b, group, ch_per_group, K * out_h * out_w) + v = reshape(g, v, [b, group, ch, K, out_h, out_w]) + + return v + + +def gather_nd(g, dcn_params, input, p_y, p_x): + """ + Gather elements specified by p_y and p_x using GatherND. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + input: input tensor. + The shape is (b, group, ch_per_group, in_h, in_w). + p_y: y coordinates of sampling points. + The shape is (b, group, K, 1, out_h, out_w). + p_x: x coordinates of sampling points. + The shape is (b, group, K, 1, out_h, out_w). + + Returns: + The shape is (b, group, ch_per_group, K, out_h, out_w). + """ + b = dcn_params["batch"] + group = dcn_params["n_offset_grps"] + ch = dcn_params["in_ch_per_group"] + out_h = dcn_params["out_h"] + out_w = dcn_params["out_w"] + K = dcn_params["kernel_area_size"] + + p_y = reshape(g, p_y, [b, group, K * out_h * out_w, 1]) + p_x = reshape(g, p_x, [b, group, K * out_h * out_w, 1]) + index = g.op("Concat", p_y, p_x, axis_i=3) + # => index.shape is (b, group, K * out_h * out_w, 2) + + input = g.op("Transpose", input, perm_i=[0, 1, 3, 4, 2]) + # => input.shape is (b, group, in_h, in_w, ch_per_group) + v = g.op("GatherND", input, index, batch_dims_i=2) + # => v.shape is (b, group, K * out_h * out_w, ch) + if dcn_params["option"]["enable_openvino_patch"]: + # OpenVINO 2021.4 has a bug related to shape of the output of GatherND. + v = reshape(g, v, [b, group, K * out_h * out_w, ch]) + v = g.op("Transpose", v, perm_i=[0, 1, 3, 2]) + v = reshape(g, v, [b, group, ch, K, out_h, out_w]) + return v + + +def gather_elements_tlbr(g, dcn_params, input, p_tlbr): + """ + Gather elements specified by p_tlbr. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + input: input tensor. + The shape is (b, group, ch_per_group, in_h, in_w). + p_tlbr: A dict, {"t": p_t, "l", p_l, "b": p_b, "r": p_r}, + which contains "t"op, "l"eft, "b"ottom, and "r"ight + coordinates around p. + The shape of p_t, ..., p_r is (b, group, K, 1, out_h, out_w). + + Returns: + A dict, {"tl": v_tl, "br": v_br, ..., "tr": v_tr}, which contains + gathred elements. + The shape of v_tl is (b, group, ch_per_group, K, out_h, out_w). + """ + tlbr = ["tl", "br", "bl", "tr"] + v_tlbr = {} + for key in tlbr: + key_y = key[0] # "t" or "b" + key_x = key[1] # "l" or "r" + p_y = p_tlbr[key_y] + p_x = p_tlbr[key_x] + if dcn_params["option"]["use_gathernd"]: + v = gather_nd(g, dcn_params, input, p_y, p_x) + else: + v = gather_elements(g, dcn_params, input, p_y, p_x) + v_tlbr[key] = v + return v_tlbr + + +def calculate_weighted_sum(g, dcn_params, v_tlbr, weight_tlbr): + """ + Calculate sum of weighted tensors. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. 
+ v_tlbr: a dict, {"tl": v_tl, "br": v_br, ..., "tr": v_tr}, which + contains gathred elements. + The shape of v_tl is (b, group, ch_per_group, K, out_h, out_w). + weight_tlbr: a dict, {"tl": weight_tl, "br": weight_br, ...}, + which contains weights for "t"op-"l"eft, "b"ottom-"r"ight, .... + The shape of weight_tl is (b, group, 1, K, out_h, out_w). + + Returns: + The shape is (b, group, ch_per_group, K, out_h, out_w). + """ + weighted_v_list = [mul(g, weight_tlbr[key], v_tlbr[key]) for key in v_tlbr] + v = g.op("Sum", *weighted_v_list) + return v + + +def apply_mask(g, dcn_params, v, mask): + """ + Apply mask tensor. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + v: input tensor. + The shape is (b, group, ch_per_group, K, out_h, out_w). + mask: mask tensor. + The shape is (b, group * K, out_h, out_w). + + Returns: + The shape is (b, group, ch_per_group, K, out_h, out_w). + """ + b = dcn_params["batch"] + group = dcn_params["n_offset_grps"] + out_h = dcn_params["out_h"] + out_w = dcn_params["out_w"] + K = dcn_params["kernel_area_size"] + + mask = reshape(g, mask, [b, group, 1, K, out_h, out_w]) + v = mul(g, v, mask) + return v + + +def reshape_v_for_conv(g, dcn_params, v): + """ + Reshape v for convolution. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + v: a reshaped tensor. + The shape is (b, group, ch_per_group, K, out_h, out_w). + + Returns: + The shape is (b, in_ch, out_h * kernel_h, out_w * kernel_w). + """ + b = dcn_params["batch"] + h = dcn_params["out_h"] + w = dcn_params["out_w"] + ch = dcn_params["in_ch"] + kernel_h = dcn_params["kernel_h"] + kernel_w = dcn_params["kernel_w"] + + v = reshape(g, v, [b, ch, kernel_h, kernel_w, h, w]) + v = g.op("Transpose", v, perm_i=[0, 1, 4, 2, 5, 3]) + return reshape(g, v, [b, ch, h * kernel_h, w * kernel_w]) + + +def apply_conv(g, dcn_params, v, weight): + """ + Apply convolution. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + v: input tensor. + The shape is (b, in_ch, out_h * kernel_h, out_w * kernel_w). + weight: weight for convolution. + The shape is (out_ch, ch_per_group, kernel_h, kernel_w). + + Returns: + The shape is (b, out_ch, out_h, out_w). + """ + weight_groups = dcn_params["n_weight_grps"] + kernel_h = dcn_params["kernel_h"] + kernel_w = dcn_params["kernel_w"] + + v = g.op("Conv", + v, + weight, + group_i=weight_groups, + kernel_shape_i=[kernel_h, kernel_w], + strides_i=[kernel_h, kernel_w]) + return v + + +def apply_bias(g, dcn_params, v, bias): + """ + Apply bias parameter. + + Args: + g: graph object. + dcn_params: parameters for deform_conv2d. + v: input tensor. + The shape is (b, out_ch, out_h, out_w). + bias: bias tensor. + The shape is (out_ch,). + + Returns: + The shape is (b, out_ch, out_h, out_w). + """ + bias = unsqueeze(g, bias, [0, 2, 3]) + v = add(g, v, bias) + return v + + +def create_dcn_params(input, weight, offset, mask, bias, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, n_weight_grps, + n_offset_grps, use_mask, option): + """ + Manage parameters for DeformConv2d. 
+ """ + additional_pad_h = additional_pad_w = 0 + if pad_h == 0: + additional_pad_h = 1 + if pad_w == 0: + additional_pad_w = 1 + + batch = get_tensor_dim_size(input, 0) + in_ch = get_tensor_dim_size(input, 1) + in_h = get_tensor_dim_size(input, 2) + 2 * (pad_h + additional_pad_h) + in_w = get_tensor_dim_size(input, 3) + 2 * (pad_w + additional_pad_w) + in_ch_per_group = in_ch // n_offset_grps + + out_ch = get_tensor_dim_size(weight, 0) + kernel_h = get_tensor_dim_size(weight, 2) + kernel_w = get_tensor_dim_size(weight, 3) + kernel_area_size = kernel_h * kernel_w + + out_h = get_tensor_dim_size(offset, 2) + out_w = get_tensor_dim_size(offset, 3) + + if JitScalarType is not None and hasattr(JitScalarType, "from_value"): + # 2.0 and later + scalar_type = JitScalarType.from_value(offset) + offset_dtype_onnx = scalar_type.onnx_type() + offset_dtype_pytorch = scalar_type.dtype() + + scalar_type = JitScalarType.from_dtype(torch.int64) + index_dtype_onnx = scalar_type.onnx_type() + index_dtype_pytorch = scalar_type.dtype() + else: + offset_dtype = sym_help._try_get_scalar_type(offset) + offset_dtype_onnx = sym_help.cast_pytorch_to_onnx[offset_dtype] + dtype_idx = sym_help.scalar_type_to_onnx.index(offset_dtype_onnx) + offset_dtype_pytorch = sym_help.scalar_type_to_pytorch_type[dtype_idx] + + index_dtype = "Long" + index_dtype_onnx = sym_help.cast_pytorch_to_onnx[index_dtype] + dtype_idx = sym_help.scalar_type_to_onnx.index(index_dtype_onnx) + index_dtype_pytorch = sym_help.scalar_type_to_pytorch_type[dtype_idx] + + dcn_params = { + # batch and kernel + "batch": batch, + "kernel_h": kernel_h, + "kernel_w": kernel_w, + "kernel_area_size": kernel_area_size, + + # input size + "in_ch": in_ch, + "in_ch_per_group": in_ch_per_group, + "in_h": in_h, + "in_w": in_w, + + # output size + "out_ch": out_ch, + "out_h": out_h, + "out_w": out_w, + + # other parameters + "stride_h": stride_h, + "stride_w": stride_w, + "dilation_h": dilation_h, + "dilation_w": dilation_w, + "n_offset_grps": n_offset_grps, + "n_weight_grps": n_weight_grps, + + # offset data type + "offset_dtype_onnx": offset_dtype_onnx, + "offset_dtype_pytorch": offset_dtype_pytorch, + + # index data type + "index_dtype_onnx": index_dtype_onnx, + "index_dtype_pytorch": index_dtype_pytorch, + + # padding + "padding_h": pad_h, + "padding_w": pad_w, + "additional_pad_h": additional_pad_h, + "additional_pad_w": additional_pad_w, + + "option": option, + } + return dcn_params + + +def deform_conv2d_func(use_gathernd, enable_openvino_patch): + @sym_help.parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", + "i", "i", "b") + def deform_conv2d(g, input, weight, offset, mask, bias, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, n_weight_grps, + n_offset_grps, use_mask): + option = { + "use_gathernd": use_gathernd, + "enable_openvino_patch": enable_openvino_patch, + } + dcn_params = create_dcn_params(input, weight, offset, mask, bias, + stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, n_weight_grps, + n_offset_grps, use_mask, option) + + p = calculate_p(g, dcn_params, offset) + p_floor = calculate_p_floor(g, dcn_params, p) + p_tlbr = calculate_p_tlbr(g, dcn_params, p_floor) + weight_tlbr = calculate_weight(g, dcn_params, p, p_floor) + + input = reshape_input_for_gather_elements(g, dcn_params, input) + v_tlbr = gather_elements_tlbr(g, dcn_params, input, p_tlbr) + + v = calculate_weighted_sum(g, dcn_params, v_tlbr, weight_tlbr) + + if use_mask: + v = apply_mask(g, dcn_params, v, mask) + + v = reshape_v_for_conv(g, 
dcn_params, v) + v = apply_conv(g, dcn_params, v, weight) + v = apply_bias(g, dcn_params, v, bias) + return v + + return deform_conv2d + + +def register_deform_conv2d_onnx_op(use_gathernd=True, + enable_openvino_patch=False): + """ + Register ONNX operator for torchvision::deform_conv2d. + + Args: + use_gathernd: If True, use GatherND. Otherwise use GatherElements. + enable_openvino_patch: If True, enable patch for OpenVINO. + Otherwise, disable it. + """ + register_custom_op_symbolic( + 'torchvision::deform_conv2d', + deform_conv2d_func(use_gathernd, enable_openvino_patch), + onnx_opset_version) diff --git a/lightglue_onnx/aliked/padder.py b/lightglue_onnx/aliked/padder.py new file mode 100644 index 0000000..efd102a --- /dev/null +++ b/lightglue_onnx/aliked/padder.py @@ -0,0 +1,24 @@ +from torch import Tensor +import torch.nn.functional as F + + +class InputPadder(object): + """ Pads images such that dimensions are divisible by 8 """ + + def __init__(self, h: int, w: int, divis_by: int=8): + self.ht = h + self.wd = w + pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by + pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + + def pad(self, x: Tensor): + assert x.ndim == 4 + return F.pad(x, self._pad, mode='replicate') + + def unpad(self, x: Tensor): + assert x.ndim == 4 + ht = x.shape[-2] + wd = x.shape[-1] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] \ No newline at end of file diff --git a/lightglue_onnx/aliked/soft_detect.py b/lightglue_onnx/aliked/soft_detect.py new file mode 100644 index 0000000..adc63c0 --- /dev/null +++ b/lightglue_onnx/aliked/soft_detect.py @@ -0,0 +1,242 @@ +import torch +from torch import nn, Tensor + +# coordinates system +# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ] +# | ----------------------------- +# | | | +# | | | +# | | | +# | | image | +# | | | +# | | | +# | | | +# | |---------------------------| +# v +# [ y: range=-1.0~1.0; h: range=0~H ] + + +def simple_nms(scores: Tensor, nms_radius: int): + """Fast Non-maximum suppression to remove nearby points""" + + zeros = torch.zeros_like(scores) + max_mask = scores == torch.nn.functional.max_pool2d( + scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + + for _ in range(2): + supp_mask = ( + torch.nn.functional.max_pool2d( + max_mask.float(), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ) + > 0 + ) + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == torch.nn.functional.max_pool2d( + supp_scores, + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +class DKD(nn.Module): + def __init__( + self, + radius: int = 2, + top_k: int = 0, + scores_th: float = 0.2, + n_limit: int = 20000, + script: bool = False, + ): + """ + Args: + radius: soft detection radius, kernel size is (2 * radius + 1) + top_k: top_k > 0: return top k keypoints + scores_th: top_k <= 0 threshold mode: scores_th > 0: return keypoints with scores>scores_th + else: return keypoints with scores > scores.mean() + n_limit: max number of keypoint in threshold mode + """ + super().__init__() + self.radius = radius + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + self.kernel_size = 2 * self.radius + 1 + 
+        self.temperature = 0.1  # tuned temperature
+        self.unfold = nn.Unfold(
+            kernel_size=self.kernel_size, padding=self.radius
+        )
+        # self.get_patches_func = get_patches_script if script else get_patches.apply
+
+        # local xy grid
+        x = torch.linspace(-self.radius, self.radius, self.kernel_size)
+        # (kernel_size*kernel_size) x 2 : (w,h)
+        self.hw_grid = (
+            torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]
+        )  # moved onto the score map's device in detect_keypoints
+        # self.hw_grid = (
+        #     torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]
+        # ).to("cuda").half()
+
+    def detect_keypoints(self, scores_map: Tensor, sub_pixel: bool = True):
+        b, c, h, w = scores_map.shape
+        scores_nograd = scores_map.detach()
+        nms_scores = simple_nms(scores_nograd, 2)
+
+        # remove border
+        nms_scores[:, :, : self.radius, :] = 0
+        nms_scores[:, :, :, : self.radius] = 0
+        nms_scores[:, :, h - self.radius:, :] = 0
+        nms_scores[:, :, :, w - self.radius:] = 0
+
+        # detect keypoints without grad
+        if self.top_k > 0:
+            topk = torch.topk(nms_scores.view(b, -1), self.top_k)
+            indices_keypoints = [
+                topk.indices[i] for i in range(b)
+            ]  # B x top_k
+        else:
+            if self.scores_th > 0:
+                masks = nms_scores > self.scores_th
+                if masks.sum() == 0:
+                    th = scores_nograd.reshape(b, -1).mean(
+                        dim=1
+                    )  # th = self.scores_th
+                    masks = nms_scores > th.reshape(b, 1, 1, 1)
+            else:
+                th = scores_nograd.reshape(b, -1).mean(
+                    dim=1
+                )  # th = self.scores_th
+                masks = nms_scores > th.reshape(b, 1, 1, 1)
+            masks = masks.reshape(b, -1)
+
+            indices_keypoints = []  # list, B x (any size)
+            scores_view = scores_nograd.reshape(b, -1)
+            for mask, scores in zip(masks, scores_view):
+                indices = mask.nonzero()[:, 0]
+                if len(indices) > self.n_limit:
+                    kpts_sc = scores[indices]
+                    sort_idx = kpts_sc.sort(descending=True)[1]
+                    sel_idx = sort_idx[: self.n_limit]
+                    indices = indices[sel_idx]
+                indices_keypoints.append(indices)
+
+        wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
+
+        keypoints = []
+        scoredispersitys = []
+        kptscores = []
+        if sub_pixel:
+            # detect soft keypoints with grad backpropagation
+            patches = self.unfold(scores_map)  # B x (kernel**2) x (H*W)
+            self.hw_grid = self.hw_grid.to(scores_map)  # to device
+            for b_idx in range(b):
+                patch = patches[b_idx].t()  # (H*W) x (kernel**2)
+                indices_kpt = indices_keypoints[
+                    b_idx
+                ]  # one dimension vector, say its size is M
+                patch_scores = patch[indices_kpt]  # M x (kernel**2)
+                keypoints_xy_nms = torch.stack(
+                    [
+                        indices_kpt % w,
+                        torch.div(indices_kpt, w, rounding_mode="trunc"),
+                    ],
+                    dim=1,
+                )  # Mx2
+
+                # max is detached to prevent undesired backprop loops in the graph
+                max_v = patch_scores.max(dim=1).values.detach()[:, None]
+                x_exp = (
+                    (patch_scores - max_v) / self.temperature
+                ).exp()  # M * (kernel**2), in [0, 1]
+
+                # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
+                xy_residual = (
+                    x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
+                )  # Soft-argmax, Mx2
+
+                hw_grid_dist2 = (
+                    torch.norm(
+                        (self.hw_grid[None, :, :] - xy_residual[:, None, :])
+                        / self.radius,
+                        dim=-1,
+                    )
+                    ** 2
+                )
+                scoredispersity = (x_exp * hw_grid_dist2).sum(
+                    dim=1
+                ) / x_exp.sum(dim=1)
+
+                # compute result keypoints
+                keypoints_xy = keypoints_xy_nms + xy_residual
+                keypoints_xy = (
+                    keypoints_xy / wh * 2 - 1
+                )  # (w,h) -> (-1~1,-1~1)
+
+                kptscore = torch.nn.functional.grid_sample(
+                    scores_map[b_idx].unsqueeze(0),
+                    keypoints_xy.view(1, 1, -1, 2),
+                    mode="bilinear",
+                    align_corners=True,
+                )[
+                    0, 0, 0, :
+                ]  # CxN
+
+                keypoints.append(keypoints_xy)
+                scoredispersitys.append(scoredispersity)
+
kptscores.append(kptscore) + else: + for b_idx in range(b): + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + # To avoid warning: UserWarning: __floordiv__ is deprecated + keypoints_xy_nms = torch.stack( + [ + indices_kpt % w, + torch.div(indices_kpt, w, rounding_mode="trunc"), + ], + dim=1, + ) # Mx2 + keypoints_xy = ( + keypoints_xy_nms / wh * 2 - 1 + ) # (w,h) -> (-1~1,-1~1) + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN + keypoints.append(keypoints_xy) + scoredispersitys.append( + kptscore + ) # for jit.script compatability + kptscores.append(kptscore) + + return keypoints, scoredispersitys, kptscores + + def forward(self, scores_map: Tensor, sub_pixel: bool = True): + """ + :param scores_map: Bx1xHxW + :param descriptor_map: BxCxHxW + :param sub_pixel: whether to use sub-pixel keypoint detection + :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0 + """ + + keypoints, scoredispersitys, kptscores = self.detect_keypoints( + scores_map, sub_pixel + ) + + keypoints = torch.stack(keypoints, dim=0) + kptscores = torch.stack(kptscores, dim=0) + scoredispersitys = torch.stack(scoredispersitys, dim=0) + + return keypoints, kptscores, scoredispersitys diff --git a/lightglue_onnx/lightglue.py b/lightglue_onnx/lightglue.py index eec3e8a..9f919a9 100644 --- a/lightglue_onnx/lightglue.py +++ b/lightglue_onnx/lightglue.py @@ -141,7 +141,8 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> Tuple[torch.Tensor]: m1 = self.inner_attn(qk1, qk0, v0) m0, m1 = map( - lambda t: t.transpose(1, 2).reshape(self.batch, -1, self.embed_dim), + lambda t: t.transpose(1, 2).reshape( + self.batch, -1, self.embed_dim), (m0, m1), ) m0, m1 = map(self.to_out, (m0, m1)) @@ -231,7 +232,8 @@ def filter_matches(scores: torch.Tensor, th: float): class LightGlue(nn.Module): default_conf = { "name": "lightglue", # just for interfacing - "input_dim": 256, # input descriptor dimension (autoselected from weights) + # input descriptor dimension (autoselected from weights) + "input_dim": 256, "descriptor_dim": 256, "n_layers": 9, "num_heads": 4, @@ -247,6 +249,10 @@ class LightGlue(nn.Module): features = { "superpoint": ("superpoint_lightglue", 256), "disk": ("disk_lightglue", 128), + "aliked-t16": ("aliked_lightglue", 64), + "aliked-n16": ("aliked_lightglue", 128), + "aliked-n16rot": ("aliked_lightglue", 128), + "aliked-n32": ("aliked_lightglue", 128), } def __init__(self, features="superpoint", **conf) -> None: @@ -258,7 +264,8 @@ def __init__(self, features="superpoint", **conf) -> None: self.conf = conf = SimpleNamespace(**self.conf) if conf.input_dim != conf.descriptor_dim: - self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) + self.input_proj = nn.Linear( + conf.input_dim, conf.descriptor_dim, bias=True) else: self.input_proj = nn.Identity() @@ -267,9 +274,11 @@ def __init__(self, features="superpoint", **conf) -> None: h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim - self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)]) + self.transformers = nn.ModuleList( + [TransformerLayer(d, h) for _ in range(n)]) - self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.log_assignment = nn.ModuleList( + [MatchAssignment(d) for _ in range(n)]) self.token_confidence = nn.ModuleList( [TokenConfidence(d) for _ in range(n - 1)] @@ -283,7 
+292,7 @@ def __init__(self, features="superpoint", **conf) -> None: if features is not None: fname = f"{conf.weights}_{self.version}.pth".replace(".", "-") state_dict = torch.hub.load_state_dict_from_url( - self.url.format(self.version, features), file_name=fname + self.url.format(self.version, features.split('-')[0]), file_name=fname ) elif conf.weights is not None: path = Path(__file__).parent @@ -294,9 +303,11 @@ def __init__(self, features="superpoint", **conf) -> None: # rename old state dict entries for i in range(n): pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" - state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + state_dict = {k.replace(*pattern): v for k, + v in state_dict.items()} pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" - state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + state_dict = {k.replace(*pattern): v for k, + v in state_dict.items()} self.load_state_dict(state_dict, strict=False) print("Loaded LightGlue model") @@ -327,7 +338,8 @@ def forward( for i in range(self.conf.n_layers): # self+cross attention - desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1) + desc0, desc1 = self.transformers[i]( + desc0, desc1, encoding0, encoding1) if i == self.conf.n_layers - 1: continue # no early stopping or adaptive width at last layer @@ -357,7 +369,8 @@ def forward( matches, mscores = filter_matches(scores, self.conf.filter_threshold) return matches, mscores # Skip unnecessary computation - m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) + m0, m1, mscores0, mscores1 = filter_matches( + scores, self.conf.filter_threshold) valid = m0[0] > -1 m_indices_0 = torch.where(valid)[0] @@ -372,8 +385,10 @@ def forward( if do_point_pruning: # scatter with indices after pruning m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) - m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) - m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + m0_[:, ind0] = torch.where( + m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where( + m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) mscores0_ = torch.zeros((b, m), device=mscores0.device) mscores1_ = torch.zeros((b, n), device=mscores1.device) mscores0_[:, ind0] = mscores0 @@ -409,5 +424,6 @@ def check_if_stop( """evaluate stopping condition""" confidences = torch.cat([confidences0, confidences1], -1) threshold = self.confidence_thresholds[layer_index] - ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points + ratio_confident = 1.0 - \ + (confidences < threshold).float().sum() / num_points return ratio_confident > self.conf.depth_confidence
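For completeness, a hedged sketch of how the registered deform_conv2d symbolic and the new ALIKED module fit together when exporting the extractor on its own, outside export.py. The input/output names and file name are illustrative; the opset must be at least 12 because the symbolic is registered for opset 12, and the spatial size is fixed because the symbolic reads static H/W from the traced tensors in create_dcn_params:

import torch
from lightglue_onnx.aliked import deform_conv2d_onnx_exporter
from lightglue_onnx.aliked.aliked import ALIKED

# Register the ONNX symbolic for torchvision::deform_conv2d (opset >= 12).
deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()

extractor = ALIKED(model_name="aliked-n16", device="cpu", top_k=512)
dummy = torch.randn(1, 3, 512, 512)  # static H and W; see note above

# ALIKED.forward returns (keypoints, scores, descriptors); keypoints are normalized to [-1, 1].
torch.onnx.export(
    extractor,
    dummy,
    "aliked-n16.onnx",  # illustrative path
    input_names=["image"],
    output_names=["keypoints", "scores", "descriptors"],
    opset_version=16,
)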