Skip to content

Commit ddd3f99

Browse files
committed
Update test, encoder_only mode for backward test
1 parent 4cc7fdb commit ddd3f99

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/test_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size):
210210
pytest.skip("Fixed input size model > limit.")
211211

212212
model = create_model(model_name, pretrained=False, num_classes=42)
213+
encoder_only = model.num_classes == 0 # FIXME better approach?
213214
num_params = sum([x.numel() for x in model.parameters()])
214215
model.train()
215216

@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
224225
assert x.grad is not None, f'No gradient for {n}'
225226
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
226227

227-
assert outputs.shape[-1] == 42
228+
if encoder_only:
229+
output_fmt = getattr(model, 'output_fmt', 'NCHW')
230+
feat_axis = get_channel_dim(output_fmt)
231+
assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
232+
else:
233+
assert outputs.shape[-1] == 42
228234
assert num_params == num_grad, 'Some parameters are missing gradients'
229235
assert not torch.isnan(outputs).any(), 'Output included NaNs'
230236

0 commit comments

Comments
 (0)