Skip to content

Commit b48aec5

Browse files
Jake VanderPlasjax authors
authored andcommitted
Require array-like inputs to sparse_plus
We should not silently convert non-array inputs to arrays, because this can lead to silent performance degredation. This brings the sparse_plus API in line with other APIs in this module. PiperOrigin-RevId: 617190413
1 parent 0b28a4b commit b48aec5

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

jax/_src/nn/functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def sparse_plus(x: ArrayLike) -> Array:
132132
Args:
133133
x: input (float)
134134
"""
135+
numpy_util.check_arraylike("sparse_plus", x)
135136
x = jnp.asarray(x)
136137
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
137138

0 commit comments

Comments
 (0)