[ONNX] Add dynamic shapes support (& in-browser inference w/ Transformers.js) #79

Open · wants to merge 2 commits into base: develop

rfdetr/main.py (20 changes: 18 additions & 2 deletions)

@@ -440,6 +440,8 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_
 
         device = self.device
         model = deepcopy(self.model.to("cpu"))
+        if backbone_only:
+            model = model.backbone
         model.to(device)
 
         os.makedirs(output_dir, exist_ok=True)
@@ -453,11 +455,25 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_
         input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device)
         input_names = ['input']
         output_names = ['features'] if backbone_only else ['dets', 'labels']
-        dynamic_axes = None
+        dynamic_axes = {
+            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
+        }
+        if backbone_only:
+            dynamic_axes.update({
+                'features': {0: 'batch_size', 2: 'num_patches_height', 3: 'num_patches_width'},
+            })
+        else:
+            dynamic_axes.update({
+                'dets': {0: 'batch_size'},
+                'labels': {0: 'batch_size'},
+            })
 
         self.model.eval()
         with torch.no_grad():
             if backbone_only:
-                features = model(input_tensors)
+                features = model(
+                    utils.nested_tensor_from_tensor_list([input_tensors[0]])
+                )[0][0].tensors
                 print(f"PyTorch inference output shape: {features.shape}")
             else:
                 outputs = model(input_tensors)
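
With these axes marked dynamic, a single exported ONNX file accepts varying batch sizes and resolutions at runtime, which is what makes in-browser inference with Transformers.js practical. A minimal sketch of exercising the dynamic axes with onnxruntime (the file path and the input sizes are illustrative assumptions, not part of this diff; the sizes are kept divisible by 14 to satisfy the DINOv2 backbone constraint):

    import numpy as np
    import onnxruntime as ort

    # Hypothetical path: wherever the export step above wrote the model.
    session = ort.InferenceSession("output/inference_model.onnx")

    # Batch size, height, and width may now differ between calls.
    for batch, h, w in [(1, 560, 560), (2, 560, 784)]:
        image = np.random.rand(batch, 3, h, w).astype(np.float32)
        dets, labels = session.run(None, {"input": image})
        print(f"{batch}x3x{h}x{w} -> dets {dets.shape}, labels {labels.shape}")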

rfdetr/models/backbone/dinov2.py (55 changes: 0 additions & 55 deletions)

@@ -67,61 +67,6 @@ def export(self):
         if self._export:
             return
         self._export = True
-        shape = self.shape
-        def make_new_interpolated_pos_encoding(
-            position_embeddings, patch_size, height, width
-        ):
-
-            num_positions = position_embeddings.shape[1] - 1
-            dim = position_embeddings.shape[-1]
-            height = height // patch_size
-            width = width // patch_size
-
-            class_pos_embed = position_embeddings[:, 0]
-            patch_pos_embed = position_embeddings[:, 1:]
-
-            # Reshape and permute
-            patch_pos_embed = patch_pos_embed.reshape(
-                1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
-            )
-            patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
-
-            # Use bicubic interpolation with antialias
-            patch_pos_embed = F.interpolate(
-                patch_pos_embed,
-                size=(height, width),
-                mode="bicubic",
-                align_corners=False,
-                antialias=True,
-            )
-
-            # Reshape back
-            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
-            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
-
-        # If the shape of self.encoder.embeddings.position_embeddings
-        # matches the shape of your new tensor, use copy_:
-        with torch.no_grad():
-            new_positions = make_new_interpolated_pos_encoding(
-                self.encoder.embeddings.position_embeddings,
-                self.encoder.config.patch_size,
-                shape[0],
-                shape[1],
-            )
-        # Create a new Parameter with the new size
-        old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding
-        def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
-            num_patches = embeddings.shape[1] - 1
-            num_positions = self_mod.position_embeddings.shape[1] - 1
-            if num_patches == num_positions and height == width:
-                return self_mod.position_embeddings
-            return old_interpolate_pos_encoding(embeddings, height, width)
-
-        self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions)
-        self.encoder.embeddings.interpolate_pos_encoding = types.MethodType(
-            new_interpolate_pos_encoding,
-            self.encoder.embeddings
-        )
 
     def forward(self, x):
         assert x.shape[2] % 14 == 0 and x.shape[3] % 14 == 0, f"Dinov2 requires input shape to be divisible by 14, but got {x.shape}"
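
The deleted hook above pre-baked the position encoding for one fixed input shape at export time, which is exactly what dynamic shapes cannot tolerate; with it gone, the encoder's own interpolate_pos_encoding resizes the learned embeddings to whatever patch grid the input yields. A condensed, illustrative sketch of that interpolation (names and shapes follow the deleted code; this is not the replacement implementation):

    import math
    import torch
    import torch.nn.functional as F

    def interpolate_pos_embed(position_embeddings, patch_size, height, width):
        # position_embeddings: (1, 1 + num_positions, dim); index 0 is the CLS token.
        num_positions = position_embeddings.shape[1] - 1
        dim = position_embeddings.shape[-1]
        grid = int(math.sqrt(num_positions))
        class_pos = position_embeddings[:, :1]
        patch_pos = position_embeddings[:, 1:].reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(
            patch_pos, size=(height // patch_size, width // patch_size),
            mode="bicubic", align_corners=False,
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos, patch_pos), dim=1)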

rfdetr/models/backbone/dinov2_with_windowed_attn.py (3 changes: 2 additions & 1 deletion)

@@ -283,7 +283,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
             size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
             mode="bicubic",
             align_corners=False,
-            antialias=True,
+            # True by default, False during export
+            antialias=not torch.onnx.is_in_onnx_export(),
         ).to(dtype=target_dtype)
 
         # Validate output dimensions if not tracing
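
torch.onnx.is_in_onnx_export() returns True only while torch.onnx.export() is tracing, so the same forward pass keeps antialiasing for eager PyTorch inference but drops it from the exported graph, where antialiased resizing is not reliably supported. A self-contained sketch of the pattern (the module and output filename are illustrative):

    import torch
    import torch.nn.functional as F

    class Resize(torch.nn.Module):
        def forward(self, x):
            # Antialiased bicubic in eager mode; plain bicubic while the ONNX
            # exporter is tracing, mirroring the change above.
            return F.interpolate(
                x, size=(64, 64), mode="bicubic", align_corners=False,
                antialias=not torch.onnx.is_in_onnx_export(),
            )

    model = Resize()
    print(model(torch.randn(1, 3, 32, 32)).shape)  # eager call: antialias=True
    torch.onnx.export(model, torch.randn(1, 3, 32, 32), "resize.onnx")  # traced: antialias=False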

rfdetr/models/ops/modules/ms_deform_attn.py (2 changes: 1 addition & 1 deletion)

@@ -95,7 +95,7 @@ def _reset_parameters(self):
 
     def forward(self, query, reference_points, input_flatten, input_spatial_shapes,
                 input_level_start_index, input_padding_mask=None):
-        """
+        r"""
         :param query                       (N, Length_{query}, C)
         :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                            or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
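
The r"" prefix added here (and in transformer.py below) keeps backslash sequences such as \sum literal: in a plain docstring Python treats \s as an invalid escape sequence, which emits a DeprecationWarning (a SyntaxWarning on Python 3.12+, slated to become an error). A small illustration of the difference:

    def plain():
        """bs, \sum{hw}, d_model"""    # warns: invalid escape sequence '\s'

    def raw():
        r"""bs, \sum{hw}, d_model"""   # raw string, stored verbatim, no warning

    print(raw.__doc__)                 # bs, \sum{hw}, d_model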

rfdetr/models/transformer.py (9 changes: 4 additions & 5 deletions)

@@ -68,7 +68,7 @@ def gen_sineembed_for_position(pos_tensor, dim=128):
 
 
 def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True):
-    """
+    r"""
     Input:
         - memory: bs, \sum{hw}, d_model
         - memory_padding_mask: bs, \sum{hw}

@@ -198,12 +198,12 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
         src_flatten = []
         mask_flatten = [] if masks is not None else None
         lvl_pos_embed_flatten = []
-        spatial_shapes = []
+        spatial_shapes = torch.empty((len(srcs), 2), device=srcs[0].device, dtype=torch.long)
         valid_ratios = [] if masks is not None else None
         for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
             bs, c, h, w = src.shape
-            spatial_shape = (h, w)
-            spatial_shapes.append(spatial_shape)
+            spatial_shapes[lvl, 0] = h
+            spatial_shapes[lvl, 1] = w
 
             src = src.flatten(2).transpose(1, 2)  # bs, hw, c
             pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c

@@ -217,7 +217,6 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
             mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
             valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
 
         if self.two_stage:
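
Building spatial_shapes as a tensor up front, instead of collecting Python ints and converting with torch.as_tensor at the end, is what lets the exported graph stay shape-agnostic: writes into a preallocated tensor keep the traced h and w symbolic, while torch.as_tensor over a list of plain ints would freeze the export-time resolutions into constants. A condensed sketch of the pattern, assuming a list of NCHW feature maps:

    import torch

    def collect_spatial_shapes(srcs):
        # Preallocate (num_levels, 2); under ONNX tracing the per-level h and w
        # stay tied to the (dynamic) input shape instead of becoming constants.
        spatial_shapes = torch.empty((len(srcs), 2), device=srcs[0].device, dtype=torch.long)
        for lvl, src in enumerate(srcs):
            _, _, h, w = src.shape
            spatial_shapes[lvl, 0] = h
            spatial_shapes[lvl, 1] = w
        return spatial_shapes

    feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 256, 40, 40)]
    print(collect_spatial_shapes(feats))  # tensor([[80, 80], [40, 40]])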