Skip to content

Commit 857727d

Browse files
committed
Simplify resolution check for improved script/trace compat
1 parent e0cb669 commit 857727d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

timm/models/mobilenetv5.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
7474
high_resolution = inputs[0].shape[-2:] # Assuming the first input is the highest resolution.
7575
resized_inputs = []
7676
for _, img in enumerate(inputs):
77-
if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
77+
feat_size = img.shape[-2:]
78+
if feat_size[0] < high_resolution[0] or feat_size[1] < high_resolution[1]:
7879
img = F.interpolate(img, size=high_resolution, mode=self.interpolation_mode)
7980
resized_inputs.append(img)
8081

0 commit comments

Comments
 (0)