
Commit c94ea14

mtthss authored and jax authors committed
Add sparse_plus activation to jax.nn.
PiperOrigin-RevId: 616087452
1 parent a5d32c4 · commit c94ea14

File tree

4 files changed: +40 −1 lines changed


docs/jax.nn.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Activation functions
     relu6
     sigmoid
     softplus
+    sparse_plus
     soft_sign
     silu
     swish

jax/_src/nn/functions.py

Lines changed: 25 additions & 0 deletions
@@ -110,6 +110,31 @@ def softplus(x: ArrayLike) -> Array:
   """
   return jnp.logaddexp(x, 0)

+@jax.jit
+def sparse_plus(x: ArrayLike) -> Array:
+  r"""Sparse plus function.
+
+  Computes the function:
+
+  .. math::
+
+    \mathrm{sparse\_plus}(x) = \begin{cases}
+      0, & x \leq -1\\
+      \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+      x, & 1 \leq x
+    \end{cases}
+
+  This is the twin function of the softplus activation: it is exactly zero
+  for inputs less than or equal to -1 and exactly linear for inputs greater
+  than or equal to 1, while remaining smooth, convex, and monotonic in
+  between.
+
+  Args:
+    x: input (float)
+  """
+  x = jnp.asarray(x)
+  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
+
 @jax.jit
 def soft_sign(x: ArrayLike) -> Array:
   r"""Soft-sign activation function.

jax/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
   soft_sign as soft_sign,
   softmax as softmax,
   softplus as softplus,
+  sparse_plus as sparse_plus,
   silu as silu,
   swish as swish,
   squareplus as squareplus,

tests/nn_test.py

Lines changed: 13 additions & 1 deletion
@@ -63,6 +63,14 @@ def testSoftplusGradNan(self):
   def testSoftplusZero(self, dtype):
     self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))

+  def testSparseplusGradZero(self):
+    check_grads(nn.sparse_plus, (-2.,), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testSparseplusGrad(self):
+    check_grads(nn.sparse_plus, (0.,), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
   def testSquareplusGrad(self):
     check_grads(nn.squareplus, (1e-8,), order=4,
                 rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
@@ -101,6 +109,10 @@ def testSoftplusValue(self):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)

+  def testSparseplusValue(self):
+    val = nn.sparse_plus(89.)
+    self.assertAllClose(val, 89., check_dtypes=False)
+
   def testSquareplusValue(self):
     val = nn.squareplus(1e3)
     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
@@ -137,7 +149,7 @@ def gelu_reference(x):
       (jnp.float32, jnp.bfloat16, jnp.float16),
       (partial(nn.gelu, approximate=False),
        partial(nn.gelu, approximate=True),
-       nn.relu, nn.softplus, nn.sigmoid, nn.squareplus)))
+       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus)))
   def testDtypeMatchesInput(self, dtype, fn):
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
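
Beyond these tests, a rough sanity check of the gradient behaviour (a sketch, not part of the commit): the derivative should go from 0 in the flat region, through (x + 1)/2 in the quadratic region, to 1 in the linear region.

import jax
from jax import nn

grad_fn = jax.grad(nn.sparse_plus)
print(grad_fn(-2.0))  # 0.0 in the flat region
print(grad_fn(0.0))   # 0.5, i.e. (x + 1) / 2 at x = 0
print(grad_fn(2.0))   # 1.0 in the linear region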
