Skip to content

Commit e4ae5e0

Browse files
committed
fix(modules): fix layer norm for chl-attn
1 parent 664a309 commit e4ae5e0

File tree

2 files changed: +2 additions, −2 deletions

cellseg_models_pytorch/modules/attention_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def __init__(
             padding=0,
             bias=True,
         )
-        self.norm = Norm("ln", num_features=squeeze_channels)
+        self.norm = Norm("ln2d", num_features=squeeze_channels)
         self.act = Activation(activation)
         self.conv_excite = Conv(
             conv,

cellseg_models_pytorch/modules/tests/test_basemodules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_act_forward(activation):
     assert output.dtype == input.dtype


-@pytest.mark.parametrize("normalization", ["bn", "bcn", "gn", "ln", None])
+@pytest.mark.parametrize("normalization", ["bn", "bcn", "gn", "ln2d", None])
 def test_norm(normalization):
     norm = Norm(normalization, num_features=3)
     input = torch.rand([1, 3, 16, 16])

Comments (0)