Skip to content

Commit 5129f73

Browse files
committed
test(modules): update transformer tests
1 parent 5619963 commit 5129f73

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

cellseg_models_pytorch/modules/tests/test_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from cellseg_models_pytorch.modules import Transformer2D
55

66

7-
@pytest.mark.parametrize("block_type", ["basic", "slice"])
8-
def test_transformer(block_type):
7+
@pytest.mark.parametrize("block_type", ["exact", "linformer"])
8+
@pytest.mark.parametrize("computation_type", ["basic", "slice"])
9+
def test_transformer(block_type, computation_type):
910
in_channels = 64
1011
B = 4
1112
H = W = 32
@@ -17,9 +18,11 @@ def test_transformer(block_type):
1718
head_dim=32,
1819
n_blocks=1,
1920
block_types=(block_type,),
21+
computation_types=(computation_type,),
2022
biases=(False,),
2123
dropouts=(0.0,),
2224
slice_size=4,
25+
seq_len=H * W,
2326
)
2427

2528
out = tr(x)

0 commit comments

Comments
 (0)