Skip to content

Commit 69795eb

Browse files
bythew3ijax authors
authored andcommitted
[Pallas] Raise NotImplementedError for strided load/store in interpret mode.
PiperOrigin-RevId: 615983065
1 parent 2048e3c commit 69795eb

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

jax/_src/pallas/primitives.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
303303
raise NotImplementedError("Only one indexer supported in discharge rule.")
304304
idx = indexers[0]
305305
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
306+
# TODO(b/329733289): support strided load/store in interpret mode.
307+
for s in idx.indices:
308+
if isinstance(s, Slice) and s.stride > 1:
309+
raise NotImplementedError("Unimplemented stride support.")
306310
indices = idx.indices
307311
scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
308312
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
@@ -404,6 +408,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
404408
raise NotImplementedError("Only one indexer supported in discharge rule.")
405409
idx = indexers[0]
406410
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
411+
# TODO(b/329733289): support strided load/store in interpret mode.
412+
for s in idx.indices:
413+
if isinstance(s, Slice) and s.stride > 1:
414+
raise NotImplementedError("Unimplemented stride support.")
407415
indices = idx.indices
408416
scalar_dims = [
409417
i

jax/_src/state/discharge.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ def _maybe_convert_to_dynamic_slice(
194194
if not all(isinstance(i, indexing.Slice) or not np.shape(i)
195195
for i in indexer.indices):
196196
return None
197+
# TODO(b/329733289): support strided load/store in interpret mode.
198+
for i in indexer.indices:
199+
if isinstance(i, indexing.Slice) and i.stride > 1:
200+
raise NotImplementedError("Unimplemented stride support.")
197201
_convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
198202
starts = tuple(
199203
_convert_i32(i.start) if isinstance(i, indexing.Slice)

0 commit comments

Comments
 (0)