Skip to content

Commit d6d8914

Browse files
committed
DOC: jax.lax.top_k: fix rendering of return values
1 parent 27de854 commit d6d8914

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jax/_src/lax/lax.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,8 +1232,10 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
12321232
k: integer specifying the number of top entries.
12331233
12341234
Returns:
1235-
values: array containing the top k values along the last axis.
1236-
indices: array containing the indices corresponding to values.
1235+
A tuple ``(values, indices)`` where
1236+
1237+
- ``values`` is an array containing the top k values along the last axis.
1238+
- ``indices`` is an array containing the indices corresponding to values.
12371239
12381240
See also:
12391241
- :func:`jax.lax.approx_max_k`

0 commit comments

Comments
 (0)