Skip to content

Commit 2048e3c

Browse files
bythew3i authored and jax authors committed
[Pallas] Add stride in Pallas dynamic slice and support strided load/store.
PiperOrigin-RevId: 615940113
1 parent 1cef1d9 commit 2048e3c

File tree

2 files changed

+96
-39
lines changed

2 files changed

+96
-39
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 72 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,6 @@
5555
from jax._src.util import safe_zip
5656
from jax._src.util import split_list
5757
from jax._src.util import unzip2
58-
from jax._src.util import unzip3
5958
from jax.experimental.mosaic.dialects import tpu
6059
import jax.numpy as jnp
6160
import numpy as np
@@ -746,47 +745,71 @@ def _maybe_cast_to_index(cast_to_index, x):
746745
return _make_index(x)
747746
return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))
748747

749-
def _index_to_start_size(idx: tuple[indexing.Slice | int | ir.Value, ...],
750-
cast_to_index: bool) -> tuple[ir.Value, int, bool]:
748+
749+
def _index_to_start_size_stride(
750+
idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
751+
) -> tuple[ir.Value, int, int, bool]:
751752
assert not isinstance(idx, slice)
752753
if isinstance(idx, indexing.Slice):
753754
start = _maybe_cast_to_index(cast_to_index, idx.start)
754755
size = idx.size
756+
stride = idx.stride
755757
squeeze = False
756758
elif isinstance(idx, int):
757759
start = _maybe_cast_to_index(cast_to_index, idx)
758760
size = 1
761+
stride = 1
759762
squeeze = True
760763
else:
761764
if np.shape(idx):
762765
raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
763766
start = _maybe_cast_to_index(cast_to_index, idx)
764767
size = 1
768+
stride = 1
765769
squeeze = True
766-
return start, size, squeeze
770+
return start, size, stride, squeeze
767771

768772

769-
def _indexer_to_start_size(
770-
indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...], *,
773+
def _indexer_to_start_size_stride(
774+
indexer: NDIndexer,
775+
ref_block_shape: tuple[int | pl_core.Mapped, ...],
776+
*,
771777
cast_to_index: bool,
772-
) -> tuple[tuple[ir.Value, ...], tuple[int, ...], tuple[bool, ...],
773-
tuple[int | pl_core.Mapped, ...]]:
778+
) -> tuple[
779+
tuple[ir.Value, ...],
780+
tuple[int, ...],
781+
tuple[int, ...],
782+
tuple[bool, ...],
783+
tuple[int | pl_core.Mapped, ...],
784+
]:
774785
indices_iter = iter(indexer.indices)
775-
starts, sizes, squeeze_dims = unzip3(
776-
(
777-
_maybe_cast_to_index(cast_to_index, 0),
778-
1,
779-
True,
780-
)
781-
if s is pl_core.mapped
782-
else _index_to_start_size(next(indices_iter), cast_to_index)
783-
for s in ref_block_shape
784-
)
786+
starts, sizes, strides, squeeze_dims = [], [], [], []
787+
for s in ref_block_shape:
788+
start, size, stride, squeeze_dim = (
789+
(
790+
_maybe_cast_to_index(cast_to_index, 0),
791+
1,
792+
1,
793+
True,
794+
)
795+
if s is pl_core.mapped
796+
else _index_to_start_size_stride(next(indices_iter), cast_to_index)
797+
)
798+
starts.append(start)
799+
sizes.append(size)
800+
strides.append(stride)
801+
squeeze_dims.append(squeeze_dim)
785802
next_index = next(indices_iter, None)
786803
assert next_index is None, (indexer.indices, ref_block_shape)
787804
new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
788805
if not squeeze)
789-
return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape
806+
return (
807+
tuple(starts),
808+
tuple(sizes),
809+
tuple(strides),
810+
tuple(squeeze_dims),
811+
new_ref_block_shape,
812+
)
790813

791814

792815
def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
@@ -796,9 +819,15 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
796819
tuple[int | pl_core.Mapped, ...]]:
797820
assert ref_block_shape is not None
798821
target_shape = indexer.get_indexer_shape()
799-
starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size(
800-
indexer, ref_block_shape, cast_to_index=False,
822+
starts, sizes, strides, squeeze_dims, ref_block_shape = (
823+
_indexer_to_start_size_stride(
824+
indexer,
825+
ref_block_shape,
826+
cast_to_index=False,
827+
)
801828
)
829+
if not all((s is None or s == 1) for s in strides):
830+
raise NotImplementedError("Strided slices of references are unsupported.")
802831
target_ref_ty = ir.MemRefType.get(
803832
tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
804833
memory_space=ref.type.memory_space)
@@ -846,14 +875,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
846875
for a in idx_aval.indices
847876
):
848877
raise ValueError("Cannot do int indexing on TPU")
849-
starts, sizes, _, _ = _indexer_to_start_size(
850-
idx, ref_block_shape, cast_to_index=True,
878+
starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
879+
idx,
880+
ref_block_shape,
881+
cast_to_index=True,
851882
)
883+
need_stride = not all((s is None or s == 1) for s in strides)
852884
load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
853885
if is_smem_load:
854886
if ctx.avals_out[0].shape:
855887
raise ValueError("Can only load scalars from SMEM")
856888
return memref.LoadOp(ref, starts).result
889+
if need_stride:
890+
load_val = tpu.StridedLoadOp(
891+
aval_to_ir_type(load_aval), ref, starts, strides
892+
).result
857893
else:
858894
load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
859895
if load_aval == aval_out:
@@ -896,10 +932,12 @@ def _masked_swap_lowering_rule(
896932
raise NotImplementedError(
897933
"Indexing into a ()-shaped Ref not yet supported on TPU.")
898934

899-
starts, _, _, _ = _indexer_to_start_size(
900-
idx, ref_block_shape, cast_to_index=True,
935+
starts, _, strides, _, _ = _indexer_to_start_size_stride(
936+
idx,
937+
ref_block_shape,
938+
cast_to_index=True,
901939
)
902-
940+
need_stride = not all((s is None or s == 1) for s in strides)
903941
if is_smem_store:
904942
if val_aval.shape:
905943
raise ValueError("Can only store scalars to SMEM")
@@ -918,7 +956,10 @@ def _masked_swap_lowering_rule(
918956
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
919957
mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
920958
_dtype_to_ir_type(mem_aval.dtype))
921-
result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
959+
if need_stride:
960+
result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
961+
else:
962+
result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
922963
if mem_aval != aval_out:
923964
# We are slicing a scalar so provided dummy 1 indices
924965
result_vec_type = ir.VectorType.get(aval_out.shape,
@@ -927,7 +968,10 @@ def _masked_swap_lowering_rule(
927968
val_vec_type = ir.VectorType.get(mem_aval.shape,
928969
_dtype_to_ir_type(mem_aval.dtype))
929970
val = vector.ShapeCastOp(val_vec_type, val).result
930-
vector.StoreOp(val, ref, starts)
971+
if need_stride:
972+
tpu.StridedStoreOp(val, ref, starts, strides)
973+
else:
974+
vector.StoreOp(val, ref, starts)
931975
return result
932976

933977

jax/_src/state/indexing.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -33,16 +33,19 @@ class Slice:
3333
"""Represents a slice with a dynamic start index and a fixed size."""
3434
start: Any
3535
size: int
36+
stride: int = 1
3637

3738
def __post_init__(self):
3839
if self.size < 0:
3940
raise ValueError("`size` must not be negative.")
41+
if self.stride < 1:
42+
raise ValueError("`stride` must be >= 1.")
4043

4144
def tree_flatten(self):
4245
# If `start` is statically known, we treat it as static information
4346
if isinstance(self.start, int):
44-
return (), (self.start, self.size)
45-
return (self.start,), (self.size,)
47+
return (), (self.start, self.size, self.stride)
48+
return (self.start,), (self.size, self.stride)
4649

4750
@classmethod
4851
def tree_unflatten(cls, aux_data, children) -> Slice:
@@ -51,21 +54,30 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
5154
@classmethod
5255
def from_slice(cls, slc: slice, size: int) -> Slice:
5356
start, stop, step = slc.indices(size)
54-
if step != 1:
55-
raise ValueError(f"slice must have a step of 1 (found: {step})")
56-
return cls(start, max(stop - start, 0))
57+
if step < 1:
58+
raise ValueError(f"slice must have a step >= 1 (found: {step})")
59+
return cls(start, max((stop - start + step - 1) // step, 0), step)
5760

5861

59-
def dslice(start: int | Array | None, size: int | None = None
60-
) -> slice | Slice:
62+
def dslice(
63+
start: int | Array | None,
64+
size: int | None = None,
65+
stride: int | None = None,
66+
) -> slice | Slice:
6167
"""Constructs a `Slice` from a start and a size."""
6268
if start is None:
6369
return slice(None)
70+
if stride is None:
71+
stride = 1
72+
if not isinstance(stride, int):
73+
raise ValueError("Non-static stride in `dslice`")
6474
if size is None:
6575
if not isinstance(start, int):
6676
raise ValueError("Non-static `dslice`")
67-
return Slice(0, start)
68-
return Slice(start, size)
77+
return Slice(0, start, stride)
78+
return Slice(start, size, stride)
79+
80+
6981
ds = dslice # Handy alias
7082

7183

@@ -113,9 +125,10 @@ def __post_init__(self):
113125
if value := _maybe_concretize(start):
114126
if value >= s:
115127
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
116-
if value + idx.size > s:
128+
if value + (idx.size - 1) * idx.stride >= s:
117129
raise ValueError(
118-
f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
130+
f"Out of bound slice: start={value}, size={idx.size},"
131+
f" stride={idx.stride}, dim={s}."
119132
)
120133
continue
121134
# The shape of indexer integers should be broadcastable up to the

0 commit comments

Comments
 (0)