File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ def __init__(
69
69
70
70
self .norm = norm_layer (self .out_channels )
71
71
72
- def forward (self , inputs : list [torch .Tensor ]) -> torch .Tensor :
72
+ def forward (self , inputs : List [torch .Tensor ]) -> torch .Tensor :
73
73
# Inputs list of [B, C, H, W] tensors
74
74
high_resolution = inputs [0 ].shape [- 2 :] # Assuming the first input is the highest resolution.
75
75
resized_inputs = []
@@ -81,7 +81,7 @@ def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
81
81
channel_cat_imgs = torch .cat (resized_inputs , dim = 1 ) # Cat on channel dim, must equal self.in_channels
82
82
img = self .ffn (channel_cat_imgs )
83
83
84
- if any ([ ro != rh for ro , rh in zip ( high_resolution , self .output_resolution )]) :
84
+ if high_resolution [ 0 ] != self . output_resolution [ 0 ] or high_resolution [ 1 ] != self .output_resolution [ 1 ] :
85
85
# Interpolate / pool to target output_resolution if highest feature resolution differs
86
86
if (
87
87
high_resolution [0 ] % self .output_resolution [0 ] != 0 or
You can’t perform that action at this time.
0 commit comments