@@ -454,7 +454,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array:
454
454
def log_softmax (x : ArrayLike ,
455
455
axis : int | tuple [int , ...] | None = - 1 ,
456
456
where : ArrayLike | None = None ,
457
- initial : ArrayLike | None = None ) -> Array :
457
+ initial : ArrayLike | None = - jnp . inf ) -> Array :
458
458
r"""Log-Softmax function.
459
459
460
460
Computes the logarithm of the :code:`softmax` function, which rescales
@@ -496,7 +496,7 @@ def log_softmax(x: ArrayLike,
496
496
def softmax (x : ArrayLike ,
497
497
axis : int | tuple [int , ...] | None = - 1 ,
498
498
where : ArrayLike | None = None ,
499
- initial : ArrayLike | None = None ) -> Array :
499
+ initial : ArrayLike | None = - jnp . inf ) -> Array :
500
500
r"""Softmax function.
501
501
502
502
Computes the function which rescales elements to the range :math:`[0, 1]`
@@ -534,7 +534,7 @@ def _softmax(
534
534
x : ArrayLike ,
535
535
axis : int | tuple [int , ...] | None = - 1 ,
536
536
where : ArrayLike | None = None ,
537
- initial : ArrayLike | None = None ) -> Array :
537
+ initial : ArrayLike | None = - jnp . inf ) -> Array :
538
538
x_max = jnp .max (x , axis , where = where , initial = initial , keepdims = True )
539
539
x_safe = x if where is None else jnp .where (where , x , initial )
540
540
unnormalized = jnp .exp (x_safe - x_max )
@@ -553,7 +553,7 @@ def _softmax_deprecated(
553
553
x : ArrayLike ,
554
554
axis : int | tuple [int , ...] | None = - 1 ,
555
555
where : ArrayLike | None = None ,
556
- initial : ArrayLike | None = None ) -> Array :
556
+ initial : ArrayLike | None = - jnp . inf ) -> Array :
557
557
x_max = jnp .max (x , axis , where = where , initial = initial , keepdims = True )
558
558
x_safe = x if where is None else jnp .where (where , x , initial )
559
559
unnormalized = jnp .exp (x_safe - lax .stop_gradient (x_max ))
0 commit comments