Commit 0b602c5

mtthss authored and jax authors committed
Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
1 parent 4d4151d commit 0b602c5

3 files changed: +49 −0 lines


jax/_src/nn/functions.py

Lines changed: 32 additions & 0 deletions
@@ -173,6 +173,38 @@ def sigmoid(x: ArrayLike) -> Array:
   """
   return lax.logistic(x)

+@jax.jit
+def sparse_sigmoid(x: ArrayLike) -> Array:
+  r"""Sparse sigmoid activation function.
+
+  Computes the function:
+
+  .. math::
+
+    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
+      0, & x \leq -1\\
+      \frac{1}{2}(x+1), & -1 < x < 1 \\
+      1, & 1 \leq x
+    \end{cases}
+
+  This is the twin function of the ``sigmoid`` activation ensuring a zero output
+  for inputs less than -1, a 1 output for inputs greater than 1, and a linear
+  output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.
+
+  For more information, see `Learning with Fenchel-Young Losses (section 6.2)
+  <https://arxiv.org/abs/1901.02324>`_.
+
+  Args:
+    x : input array
+
+  Returns:
+    An array.
+
+  See also:
+    :func:`sigmoid`
+  """
+  return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0)
+
 @jax.jit
 def silu(x: ArrayLike) -> Array:
   r"""SiLU (aka swish) activation function.

jax/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
   softmax as softmax,
   softplus as softplus,
   sparse_plus as sparse_plus,
+  sparse_sigmoid as sparse_sigmoid,
   silu as silu,
   swish as swish,
   squareplus as squareplus,
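
With this one-line re-export in place (assuming the usual from-import of jax._src.nn.functions in jax/nn/__init__.py), the public name and the private definition refer to the same function object. A minimal check, not part of the commit:

import jax
from jax._src.nn import functions as _functions  # private module, used here for illustration only

# The public alias points at the implementation in jax/_src/nn/functions.py.
assert jax.nn.sparse_sigmoid is _functions.sparse_sigmoid
print(jax.nn.sparse_sigmoid(0.0))  # 0.5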

tests/nn_test.py

Lines changed: 16 additions & 0 deletions
@@ -71,6 +71,17 @@ def testSparseplusGrad(self):
     check_grads(nn.sparse_plus, (0.,), order=1,
                 rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

+  def testSparseplusAndSparseSigmoid(self):
+    self.assertAllClose(
+        jax.grad(nn.sparse_plus)(0.), nn.sparse_sigmoid(0.),
+        check_dtypes=False)
+    self.assertAllClose(
+        jax.grad(nn.sparse_plus)(2.), nn.sparse_sigmoid(2.),
+        check_dtypes=False)
+    self.assertAllClose(
+        jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.),
+        check_dtypes=False)
+
   def testSquareplusGrad(self):
     check_grads(nn.squareplus, (1e-8,), order=4,
                 rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
@@ -133,6 +144,11 @@ def testSparseplusValue(self):
     val = nn.sparse_plus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)

+  def testSparsesigmoidValue(self):
+    self.assertAllClose(nn.sparse_sigmoid(-2.), 0., check_dtypes=False)
+    self.assertAllClose(nn.sparse_sigmoid(2.), 1., check_dtypes=False)
+    self.assertAllClose(nn.sparse_sigmoid(0.), .5, check_dtypes=False)
+
   def testSquareplusValue(self):
     val = nn.squareplus(1e3)
     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
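
The committed tests spot-check the sparse_plus / sparse_sigmoid derivative relationship at three points. A denser check along the same lines, sketched here as a hypothetical addition rather than part of the commit, could vectorize the gradient over a grid:

import jax
import jax.numpy as jnp
from jax import nn

# Mirror testSparseplusAndSparseSigmoid at many points: the elementwise
# gradient of sparse_plus should coincide with sparse_sigmoid everywhere
# away from the kinks at x = -1 and x = 1.
xs = jnp.linspace(-3.0, 3.0, 101)  # grid chosen so no point lands exactly on a kink
grads = jax.vmap(jax.grad(nn.sparse_plus))(xs)
assert jnp.allclose(grads, nn.sparse_sigmoid(xs))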
