1 file changed, +7 −1 lines changed

@@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size):
         pytest.skip("Fixed input size model > limit.")
 
     model = create_model(model_name, pretrained=False, num_classes=42)
+    encoder_only = model.num_classes == 0  # FIXME better approach?
     num_params = sum([x.numel() for x in model.parameters()])
     model.train()
 
@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
         assert x.grad is not None, f'No gradient for {n}'
     num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
 
-    assert outputs.shape[-1] == 42
+    if encoder_only:
+        output_fmt = getattr(model, 'output_fmt', 'NCHW')
+        feat_axis = get_channel_dim(output_fmt)
+        assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
+    else:
+        assert outputs.shape[-1] == 42
     assert num_params == num_grad, 'Some parameters are missing gradients'
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
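The new branch relies on get_channel_dim to locate the feature axis of an unpooled (num_classes=0) encoder output. The real helper lives in timm.layers; the following is only a minimal sketch of the assumed mapping, covering the layouts timm encoders commonly emit:

    def get_channel_dim(fmt: str) -> int:
        # Channels-first layouts keep features on axis 1; channels-last
        # layouts keep them on the last axis after the spatial/sequence dims.
        if fmt in ('NCHW', 'NCL'):
            return 1
        if fmt == 'NHWC':
            return 3
        if fmt == 'NLC':
            return 2
        raise ValueError(f'Unknown output format: {fmt}')

Under this assumption, an NHWC output of shape (batch, H, W, num_features) is checked at axis 3, while the default NCHW case checks axis 1; classifier outputs still fall through to the original shape[-1] == 42 check.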