Skip to content

Commit 14e8d91

Browse files
committed
test: update tests
1 parent 0d8c05e commit 14e8d91

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

cellseg_models_pytorch/decoders/tests/test_decoders.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from cellseg_models_pytorch.decoders import Decoder
55

66

7-
@pytest.mark.parametrize("long_skip", ["unet", "unetpp", "unet3p", "unet3p-lite"])
7+
@pytest.mark.parametrize(
8+
"long_skip", ["unet", "unetpp", "unet3p", "unet3p-lite", "cross-attn"]
9+
)
810
@pytest.mark.parametrize("merge_policy", ["cat", "sum"])
911
@pytest.mark.parametrize("use_conv", [True, False])
1012
@pytest.mark.parametrize("use_tr", [True, False])
1113
def test_decoder_fwdbwd(long_skip, merge_policy, use_conv, use_tr):
12-
enc_channels = (64, 32, 64, 32)
14+
enc_channels = (32, 32, 32, 32)
1315
out_dims = [32 // 2**i for i in range(4)][::-1]
1416

1517
decoder1_kwargs = {"merge_policy": merge_policy}
@@ -40,8 +42,8 @@ def test_decoder_fwdbwd(long_skip, merge_policy, use_conv, use_tr):
4042
decoder = Decoder(
4143
enc_channels=enc_channels,
4244
out_channels=(32, 32, 32, 32),
43-
n_layers=n_layers,
44-
n_blocks=n_blocks,
45+
n_conv_layers=n_layers,
46+
n_conv_blocks=n_blocks,
4547
n_transformers=n_tr_layers,
4648
n_transformer_blocks=n_tr_blocks,
4749
long_skip=long_skip,

cellseg_models_pytorch/models/tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def test_multitaskunet_fwdbwd():
8484
m = MultiTaskUnet(
8585
decoders=("sem",),
8686
heads={"sem": {"sem": 3}},
87-
n_layers={"sem": (1, 1, 1, 1)},
88-
n_blocks={"sem": ((2,), (2,), (2,), (2,))},
87+
n_conv_layers={"sem": (1, 1, 1, 1)},
88+
n_conv_blocks={"sem": ((2,), (2,), (2,), (2,))},
8989
out_channels={"sem": (128, 64, 32, 16)},
9090
long_skips={"sem": "unet"},
9191
dec_params={"sem": None},

cellseg_models_pytorch/modules/tests/test_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def test_transformer(block_type, computation_type):
2121
computation_types=(computation_type,),
2222
biases=(False,),
2323
dropouts=(0.0,),
24+
layer_scales=(False,),
2425
slice_size=4,
2526
seq_len=H * W,
2627
)

0 commit comments

Comments
 (0)