Skip to content

Commit 4a2aa89

Browse files
committed
refactor: adjust acts to take in arbitrary kwargs
1 parent 5d89980 commit 4a2aa89

File tree

5 files changed

+7
-6
lines changed

5 files changed

+7
-6
lines changed

cellseg_models_pytorch/modules/act/gated_gelu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
class GEGLU(nn.Module):
29-
def __init__(self, dim_in: int, dim_out: int):
29+
def __init__(self, dim_in: int, dim_out: int, **kwargs) -> None:
3030
"""Apply a variant of the gated linear unit activation function.
3131
3232
https://arxiv.org/abs/2002.05202.
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5252

5353

5454
class ApproximateGELU(nn.Module):
55-
def __init__(self, dim_in: int, dim_out: int):
55+
def __init__(self, dim_in: int, dim_out: int, **kwargs) -> None:
5656
"""Apply the approximate form of Gaussian Error Linear Unit (GELU).
5757
5858
https://arxiv.org/abs/1606.08415

cellseg_models_pytorch/modules/act/mish.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def mish(x: torch.Tensor):
5050

5151

5252
class Mish(nn.Module):
53-
def __init__(self, inplace: bool = False) -> None:
53+
def __init__(self, inplace: bool = False, **kwargs) -> None:
5454
"""
5555
Element-wise mish.
5656

cellseg_models_pytorch/modules/act/star_relu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(
1212
scale_learnable: bool = True,
1313
bias_learnable: bool = True,
1414
inplace: bool = False,
15+
**kwargs
1516
) -> None:
1617
"""Apply StarReLU activation.
1718

cellseg_models_pytorch/modules/act/swish.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def swish(x: torch.Tensor) -> torch.Tensor:
3939

4040

4141
class Swish(nn.Module):
42-
def __init__(self, inplace: bool = False) -> None:
42+
def __init__(self, inplace: bool = False, **kwargs) -> None:
4343
"""Apply the element-wise swish function.
4444
4545
Parameters

cellseg_models_pytorch/utils/img_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def normalize(
131131
im = im - im.mean(axis=axis, keepdims=True)
132132

133133
if standardize:
134-
im /= im.std(axis=axis, keepdims=True)
134+
im = im / (im.std(axis=axis, keepdims=True) + 1e-8)
135135

136136
# clamp
137137
if not any(x is None for x in (amin, amax)):
@@ -170,7 +170,7 @@ def minmax_normalize(
170170
)
171171

172172
im = img.copy()
173-
im = (im - im.min()) / (im.max() - im.min())
173+
im = (im - im.min()) / (im.max() - im.min() + 1e-8)
174174

175175
# clamp
176176
if not any(x is None for x in (amin, amax)):

0 commit comments

Comments
 (0)