
Commit 29a2762

Committed by jax authors
Merge pull request #20558 from carlosgmartin:mish
PiperOrigin-RevId: 621708823
2 parents: 5cbb26f + f0314c7

4 files changed: +52 -3 lines changed

docs/jax.nn.rst

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ Activation functions
     gelu
     glu
     squareplus
+    mish
 
 Other functions
 ---------------

jax/_src/nn/functions.py

Lines changed: 25 additions & 2 deletions
@@ -199,6 +199,29 @@ def silu(x: ArrayLike) -> Array:
 
 swish = silu
 
+@jax.jit
+def mish(x: ArrayLike) -> Array:
+  r"""Mish activation function.
+
+  Computes the element-wise function:
+
+  .. math::
+    \mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))
+
+  For more information, see
+  `Mish: A Self Regularized Non-Monotonic Activation Function
+  <https://arxiv.org/abs/1908.08681>`_.
+
+  Args:
+    x : input array
+
+  Returns:
+    An array.
+  """
+  numpy_util.check_arraylike("mish", x)
+  x_arr = jnp.asarray(x)
+  return x_arr * jnp.tanh(softplus(x_arr))
+
 @jax.jit
 def log_sigmoid(x: ArrayLike) -> Array:
   r"""Log-sigmoid activation function.
@@ -314,7 +337,7 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
 
   For more information, see
   `Continuously Differentiable Exponential Linear Units
-  <https://arxiv.org/pdf/1704.07483.pdf>`_.
+  <https://arxiv.org/abs/1704.07483>`_.
 
   Args:
     x : input array
@@ -342,7 +365,7 @@ def selu(x: ArrayLike) -> Array:
 
   For more information, see
   `Self-Normalizing Neural Networks
-  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
+  <https://arxiv.org/abs/1706.02515>`_.
 
   Args:
     x : input array
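
As a quick sanity check of the new API, here is a minimal sketch (not part of the commit) that exercises jax.nn.mish and compares it against the docstring formula x * tanh(softplus(x)); the input range and tolerance are arbitrary choices:

import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-5.0, 5.0, 11)

# mish(x) = x * tanh(softplus(x)), computed element-wise.
reference = x * jnp.tanh(nn.softplus(x))

# The two should agree to floating-point tolerance.
assert jnp.allclose(nn.mish(x), reference, atol=1e-6)

Note that because mish is decorated with @jax.jit, the first call compiles the computation; later calls with the same input shape and dtype reuse the compiled version.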

jax/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
   silu as silu,
   swish as swish,
   squareplus as squareplus,
+  mish as mish,
 )
 
 # Deprecations

tests/nn_test.py

Lines changed: 25 additions & 1 deletion
@@ -91,6 +91,26 @@ def testSquareplusGradNan(self):
   def testSquareplusZero(self, dtype):
     self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))
 
+  def testMishGrad(self):
+    check_grads(nn.mish, (1e-8,), order=4,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradZero(self):
+    check_grads(nn.mish, (0.,), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradNegInf(self):
+    check_grads(nn.mish, (-float('inf'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradNan(self):
+    check_grads(nn.mish, (float('nan'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  @parameterized.parameters([float] + jtu.dtypes.floating)
+  def testMishZero(self, dtype):
+    self.assertEqual(dtype(0), nn.mish(dtype(0)))
+
   def testReluGrad(self):
     rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
     check_grads(nn.relu, (1.,), order=3, rtol=rtol)
@@ -117,6 +137,10 @@ def testSquareplusValue(self):
     val = nn.squareplus(1e3)
     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
+  def testMishValue(self):
+    val = nn.mish(1e3)
+    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
     check_grads(nn.elu, (1e4,), order=4, eps=1.)
@@ -149,7 +173,7 @@ def gelu_reference(x):
       (jnp.float32, jnp.bfloat16, jnp.float16),
       (partial(nn.gelu, approximate=False),
        partial(nn.gelu, approximate=True),
-       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus)))
+       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
   def testDtypeMatchesInput(self, dtype, fn):
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
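
The tests above lean on jax.test_util.check_grads, which compares JAX's forward- and reverse-mode derivatives against numerical differences up to the requested order, with the edge-case tests probing 0, -inf, and nan inputs. A standalone sketch of the same style of check (the sample point 0.5 is an arbitrary choice, not from the commit):

import jax
from jax import nn
from jax.test_util import check_grads

# Compare analytic forward- and reverse-mode derivatives of mish
# against numerical differences, up to second order.
check_grads(nn.mish, (0.5,), order=2)

# The derivative at 0 is tanh(log 2) = 0.6 exactly, since
# d/dx [x * tanh(softplus(x))] reduces to tanh(softplus(0)) at x = 0.
print(jax.grad(nn.mish)(0.0))  # ~0.6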
