Commit 664a309
feat(modules): add star-relu
1 parent dee38ed · commit 664a309

2 files changed (+54, -4 lines)

cellseg_models_pytorch/modules/act/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -19,6 +19,7 @@

 from .gated_gelu import GEGLU, ApproximateGELU
 from .mish import Mish
+from .star_relu import StarReLU
 from .swish import Swish

 ACT_LOOKUP = {
@@ -40,11 +41,11 @@
     "hardshrink": Hardshrink,
     "tanhshrink": Tanhshrink,
     "hardsigmoid": Hardsigmoid,
+    "star_relu": StarReLU,
+    "geglu": GEGLU,
+    "approximate_geglu": ApproximateGELU,
 }

-TR_ACT_LOOKUP = {"geglu": GEGLU, "approximate_geglu": ApproximateGELU}
-
-
 __all__ = [
     "Mish",
     "Swish",
@@ -68,5 +69,4 @@
     "ACT_LOOKUP",
     "GEGLU",
     "ApproximateGELU",
-    "TR_ACT_LOOKUP",
 ]
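With `TR_ACT_LOOKUP` folded into the shared `ACT_LOOKUP`, every activation, including the new StarReLU, is resolved through a single registry. As a minimal sketch (key names taken from the diff above), an activation can be looked up and instantiated like this:

```python
from cellseg_models_pytorch.modules.act import ACT_LOOKUP

# Resolve the activation class by its registry key and instantiate it.
# "star_relu", "geglu", and "approximate_geglu" are now all served by
# the same lookup table; the separate TR_ACT_LOOKUP is gone.
act = ACT_LOOKUP["star_relu"]()
```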
cellseg_models_pytorch/modules/act/star_relu.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+__all__ = ["StarReLU"]
+
+
+class StarReLU(nn.Module):
+    def __init__(
+        self,
+        scale_value: float = 1.0,
+        bias_value: float = 0.0,
+        scale_learnable: bool = True,
+        bias_learnable: bool = True,
+        inplace: bool = False,
+    ) -> None:
+        """Apply StarReLU activation.
+
+        Adapted from:
+        https://github.com/sail-sg/metaformer/blob/main/metaformer_baselines.py
+
+        See MetaFormer: https://arxiv.org/abs/2210.13452
+
+        StarReLU: s * relu(x) ** 2 + b
+
+        Parameters
+        ----------
+        scale_value : float, default=1.0
+            Initial value of the learnable scaling factor.
+        bias_value : float, default=0.0
+            Initial value of the learnable bias term.
+        scale_learnable : bool, default=True
+            Flag, whether to keep the scale factor learnable.
+        bias_learnable : bool, default=True
+            Flag, whether to keep the bias term learnable.
+        inplace : bool, default=False
+            Flag, whether to apply the ReLU in-place.
+        """
+        super().__init__()
+        self.inplace = inplace
+        self.relu = nn.ReLU(inplace=inplace)
+        self.scale = nn.Parameter(
+            scale_value * torch.ones(1), requires_grad=scale_learnable
+        )
+        self.bias = nn.Parameter(
+            bias_value * torch.ones(1), requires_grad=bias_learnable
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the StarReLU."""
+        return self.scale * self.relu(x) ** 2 + self.bias
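For reference, a small usage sketch of the new module. The assertion relies only on the defaults shown in the diff (`scale_value=1.0`, `bias_value=0.0`), under which the activation reduces to `relu(x) ** 2`:

```python
import torch

from cellseg_models_pytorch.modules.act.star_relu import StarReLU

act = StarReLU()  # defaults: scale=1.0, bias=0.0, both learnable
x = torch.randn(2, 3, 8, 8)
out = act(x)

# With the default scale and bias, StarReLU(x) == relu(x) ** 2.
assert torch.allclose(out, torch.relu(x) ** 2)

# `scale` and `bias` are registered as nn.Parameters, so they train with
# the rest of the model unless {scale,bias}_learnable=False.
print([name for name, _ in act.named_parameters()])  # ['scale', 'bias']
```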
