We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents ac1a53d + b278312 commit b729300Copy full SHA for b729300
jax/_src/lax/slicing.py
@@ -370,7 +370,7 @@ class ScatterDimensionNumbers(NamedTuple):
370
are the mirror image of `collapsed_slice_dims` in the case of `gather`.
371
scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
372
the corresponding dimension in `operand`. Must be a sequence of integers
373
- with size equal to indices.shape[-1].
+ with size equal to `scatter_indices.shape[-1]`.
374
375
Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is
376
implicit; there is always an index vector dimension and it must always be the
0 commit comments