Skip to content

Commit b729300

Browse files
author
jax authors
committed
Merge pull request #20762 from j-towns:scatter-doc-correction
PiperOrigin-RevId: 624971136
2 parents ac1a53d + b278312 commit b729300

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/lax/slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ class ScatterDimensionNumbers(NamedTuple):
370370
are the mirror image of `collapsed_slice_dims` in the case of `gather`.
371371
scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
372372
the corresponding dimension in `operand`. Must be a sequence of integers
373-
with size equal to indices.shape[-1].
373+
with size equal to `scatter_indices.shape[-1]`.
374374
375375
Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is
376376
implicit; there is always an index vector dimension and it must always be the

0 commit comments

Comments
 (0)