Skip to content

Commit 1690574

Browse files
committed
Fix torchscript compat of MobileNetV5 MSFA
1 parent e83e251 commit 1690574

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/mobilenetv5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969

7070
self.norm = norm_layer(self.out_channels)
7171

72-
def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
72+
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
7373
# Inputs list of [B, C, H, W] tensors
7474
high_resolution = inputs[0].shape[-2:] # Assuming the first input is the highest resolution.
7575
resized_inputs = []
@@ -81,7 +81,7 @@ def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
8181
channel_cat_imgs = torch.cat(resized_inputs, dim=1) # Cat on channel dim, must equal self.in_channels
8282
img = self.ffn(channel_cat_imgs)
8383

84-
if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
84+
if high_resolution[0] != self.output_resolution[0] or high_resolution[1] != self.output_resolution[1]:
8585
# Interpolate / pool to target output_resolution if highest feature resolution differs
8686
if (
8787
high_resolution[0] % self.output_resolution[0] != 0 or

0 commit comments

Comments
 (0)