
Commit 7b486f4

Author: jax authors (committed)
Merge pull request #20643 from carlosgmartin:softmax_initial
PiperOrigin-RevId: 622949028
2 parents 0dd2408 + 9c347b9 commit 7b486f4

1 file changed: 4 additions, 4 deletions

jax/_src/nn/functions.py

Lines changed: 4 additions & 4 deletions
@@ -454,7 +454,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array:
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
                 where: ArrayLike | None = None,
-                initial: ArrayLike | None = None) -> Array:
+                initial: ArrayLike | None = -jnp.inf) -> Array:
   r"""Log-Softmax function.

   Computes the logarithm of the :code:`softmax` function, which rescales
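
For readers skimming the diff, a minimal standalone sketch of what log_softmax computes may help. This is not the library's implementation; the helper name log_softmax_sketch is invented for illustration. It uses the standard identity log_softmax(x) = x - logsumexp(x), with a max shift for numerical stability.

import jax.numpy as jnp
from jax import nn

def log_softmax_sketch(x, axis=-1):
  # Shift by the max for numerical stability; the shift cancels in the result.
  x_max = jnp.max(x, axis=axis, keepdims=True)
  shifted = x - x_max
  return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True))

x = jnp.array([1.0, 2.0, 3.0])
print(log_softmax_sketch(x))   # roughly [-2.408, -1.408, -0.408]
print(nn.log_softmax(x))       # the library function should agree closely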
@@ -496,7 +496,7 @@ def log_softmax(x: ArrayLike,
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
             where: ArrayLike | None = None,
-            initial: ArrayLike | None = None) -> Array:
+            initial: ArrayLike | None = -jnp.inf) -> Array:
   r"""Softmax function.

   Computes the function which rescales elements to the range :math:`[0, 1]`
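
As a quick sanity check on the public API (a sketch, not part of this commit): softmax is the exponential of log_softmax, which is why the two signatures change in lockstep here.

import jax.numpy as jnp
from jax import nn

x = jnp.array([1.0, 2.0, 3.0])
print(nn.softmax(x))               # roughly [0.090, 0.245, 0.665]
print(jnp.exp(nn.log_softmax(x)))  # should match softmax(x) up to float error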
@@ -534,7 +534,7 @@ def _softmax(
     x: ArrayLike,
     axis: int | tuple[int, ...] | None = -1,
     where: ArrayLike | None = None,
-    initial: ArrayLike | None = None) -> Array:
+    initial: ArrayLike | None = -jnp.inf) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   x_safe = x if where is None else jnp.where(where, x, initial)
   unnormalized = jnp.exp(x_safe - x_max)
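
The body above is where the new default matters. Below is a self-contained sketch of the same mechanics, assuming only jax.numpy (the helper name masked_softmax_sketch is invented for illustration): -jnp.inf is the identity element for a max reduction, so the masked max is well defined, and exp(-inf) == 0 sends masked positions to zero weight.

import jax.numpy as jnp

def masked_softmax_sketch(x, mask, axis=-1, initial=-jnp.inf):
  # With initial=-inf, the masked max reduction has a well-defined identity.
  x_max = jnp.max(x, axis=axis, where=mask, initial=initial, keepdims=True)
  # Masked entries become -inf, so exp(x_safe - x_max) sends them to 0.
  x_safe = jnp.where(mask, x, initial)
  unnormalized = jnp.exp(x_safe - x_max)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([True, True, True, False])
print(masked_softmax_sketch(x, mask))  # last entry is 0; first three sum to 1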
@@ -553,7 +553,7 @@ def _softmax_deprecated(
     x: ArrayLike,
     axis: int | tuple[int, ...] | None = -1,
     where: ArrayLike | None = None,
-    initial: ArrayLike | None = None) -> Array:
+    initial: ArrayLike | None = -jnp.inf) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   x_safe = x if where is None else jnp.where(where, x, initial)
   unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
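
From the caller's side, the practical effect of this commit appears to be that `where` can now be passed on its own: with the old default of None, the masked max reduction had no identity value, so callers had to supply `initial` themselves. A hedged usage sketch:

import jax.numpy as jnp
from jax import nn

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
mask = jnp.array([[True, True, True, False]])

probs = nn.softmax(x, where=mask)   # no explicit `initial` needed anymore
print(probs)                        # roughly [[0.090, 0.245, 0.665, 0.0]]
print(probs.sum(axis=-1))           # sums to 1 over the unmasked entries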
