We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5619963 commit 5129f73Copy full SHA for 5129f73
cellseg_models_pytorch/modules/tests/test_transformer.py
@@ -4,8 +4,9 @@
4
from cellseg_models_pytorch.modules import Transformer2D
5
6
7
-@pytest.mark.parametrize("block_type", ["basic", "slice"])
8
-def test_transformer(block_type):
+@pytest.mark.parametrize("block_type", ["exact", "linformer"])
+@pytest.mark.parametrize("computation_type", ["basic", "slice"])
9
+def test_transformer(block_type, computation_type):
10
in_channels = 64
11
B = 4
12
H = W = 32
@@ -17,9 +18,11 @@ def test_transformer(block_type):
17
18
head_dim=32,
19
n_blocks=1,
20
block_types=(block_type,),
21
+ computation_types=(computation_type,),
22
biases=(False,),
23
dropouts=(0.0,),
24
slice_size=4,
25
+ seq_len=H * W,
26
)
27
28
out = tr(x)
0 commit comments