 from jax._src.util import safe_zip
 from jax._src.util import split_list
 from jax._src.util import unzip2
-from jax._src.util import unzip3
 from jax.experimental.mosaic.dialects import tpu
 import jax.numpy as jnp
 import numpy as np
@@ -746,47 +745,71 @@ def _maybe_cast_to_index(cast_to_index, x):
     return _make_index(x)
   return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))

-def _index_to_start_size(idx: tuple[indexing.Slice | int | ir.Value, ...],
-                         cast_to_index: bool) -> tuple[ir.Value, int, bool]:
+
+def _index_to_start_size_stride(
+    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
+) -> tuple[ir.Value, int, int, bool]:
   assert not isinstance(idx, slice)
   if isinstance(idx, indexing.Slice):
     start = _maybe_cast_to_index(cast_to_index, idx.start)
     size = idx.size
+    stride = idx.stride
     squeeze = False
   elif isinstance(idx, int):
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
   else:
     if np.shape(idx):
       raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
     start = _maybe_cast_to_index(cast_to_index, idx)
     size = 1
+    stride = 1
     squeeze = True
-  return start, size, squeeze
+  return start, size, stride, squeeze


-def _indexer_to_start_size(
-    indexer: NDIndexer, ref_block_shape: tuple[int | pl_core.Mapped, ...], *,
+def _indexer_to_start_size_stride(
+    indexer: NDIndexer,
+    ref_block_shape: tuple[int | pl_core.Mapped, ...],
+    *,
     cast_to_index: bool,
-) -> tuple[tuple[ir.Value, ...], tuple[int, ...], tuple[bool, ...],
-           tuple[int | pl_core.Mapped, ...]]:
+) -> tuple[
+    tuple[ir.Value, ...],
+    tuple[int, ...],
+    tuple[int, ...],
+    tuple[bool, ...],
+    tuple[int | pl_core.Mapped, ...],
+]:
   indices_iter = iter(indexer.indices)
-  starts, sizes, squeeze_dims = unzip3(
-      (
-          _maybe_cast_to_index(cast_to_index, 0),
-          1,
-          True,
-      )
-      if s is pl_core.mapped
-      else _index_to_start_size(next(indices_iter), cast_to_index)
-      for s in ref_block_shape
-  )
+  starts, sizes, strides, squeeze_dims = [], [], [], []
+  for s in ref_block_shape:
+    start, size, stride, squeeze_dim = (
+        (
+            _maybe_cast_to_index(cast_to_index, 0),
+            1,
+            1,
+            True,
+        )
+        if s is pl_core.mapped
+        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
+    )
+    starts.append(start)
+    sizes.append(size)
+    strides.append(stride)
+    squeeze_dims.append(squeeze_dim)
   next_index = next(indices_iter, None)
   assert next_index is None, (indexer.indices, ref_block_shape)
   new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                               if not squeeze)
-  return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape
+  return (
+      tuple(starts),
+      tuple(sizes),
+      tuple(strides),
+      tuple(squeeze_dims),
+      new_ref_block_shape,
+  )


 def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
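Note: the helpers above decompose each index into a (start, size, stride, squeeze) tuple. A minimal, self-contained sketch of that decomposition, using a hypothetical stand-in Slice dataclass rather than the real jax._src.state.indexing.Slice:

# Illustrative mirror of the decomposition performed by the new helpers;
# `Slice` here is a stand-in class, not the real indexing.Slice.
from dataclasses import dataclass

@dataclass
class Slice:
  start: int
  size: int
  stride: int = 1

def index_to_start_size_stride(idx):
  if isinstance(idx, Slice):
    # A slice keeps its dimension: size and stride come from the slice itself.
    return idx.start, idx.size, idx.stride, False
  # A scalar index selects a single element and squeezes that dimension away.
  return int(idx), 1, 1, True

assert index_to_start_size_stride(Slice(2, 3, stride=4)) == (2, 3, 4, False)
assert index_to_start_size_stride(7) == (7, 1, 1, True)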
@@ -796,9 +819,15 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
                              tuple[int | pl_core.Mapped, ...]]:
   assert ref_block_shape is not None
   target_shape = indexer.get_indexer_shape()
-  starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size(
-      indexer, ref_block_shape, cast_to_index=False,
+  starts, sizes, strides, squeeze_dims, ref_block_shape = (
+      _indexer_to_start_size_stride(
+          indexer,
+          ref_block_shape,
+          cast_to_index=False,
+      )
   )
+  if not all((s is None or s == 1) for s in strides):
+    raise NotImplementedError("Strided slices of references are unsupported.")
   target_ref_ty = ir.MemRefType.get(
       tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
       memory_space=ref.type.memory_space)
@@ -846,14 +875,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
       for a in idx_aval.indices
   ):
     raise ValueError("Cannot do int indexing on TPU")
-  starts, sizes, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
+  need_stride = not all((s is None or s == 1) for s in strides)
   load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
   if is_smem_load:
     if ctx.avals_out[0].shape:
       raise ValueError("Can only load scalars from SMEM")
     return memref.LoadOp(ref, starts).result
+  if need_stride:
+    load_val = tpu.StridedLoadOp(
+        aval_to_ir_type(load_aval), ref, starts, strides
+    ).result
   else:
     load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
   if load_aval == aval_out:
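Note: a rough sketch of the dispatch the load rule now performs, with NumPy slicing standing in for vector.LoadOp and tpu.StridedLoadOp (the stand-in function and semantics below are illustrative only):

# Hypothetical stand-in for the lowering decision: take the strided path only
# when some stride differs from 1, otherwise keep the contiguous vector load.
import numpy as np

def lower_load(buf: np.ndarray, starts, sizes, strides):
  need_stride = not all(s is None or s == 1 for s in strides)
  if need_stride:
    # Analogue of tpu.StridedLoadOp: read a strided window.
    window = tuple(slice(st, st + sz * sd, sd)
                   for st, sz, sd in zip(starts, sizes, strides))
  else:
    # Analogue of vector.LoadOp: read a contiguous window.
    window = tuple(slice(st, st + sz) for st, sz in zip(starts, sizes))
  return buf[window]

buf = np.arange(32).reshape(4, 8)
contiguous = lower_load(buf, starts=(0, 0), sizes=(2, 4), strides=(1, 1))
strided = lower_load(buf, starts=(0, 0), sizes=(2, 4), strides=(2, 2))
assert contiguous.shape == strided.shape == (2, 4)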
@@ -896,10 +932,12 @@ def _masked_swap_lowering_rule(
     raise NotImplementedError(
         "Indexing into a ()-shaped Ref not yet supported on TPU.")

-  starts, _, _, _ = _indexer_to_start_size(
-      idx, ref_block_shape, cast_to_index=True,
+  starts, _, strides, _, _ = _indexer_to_start_size_stride(
+      idx,
+      ref_block_shape,
+      cast_to_index=True,
   )
-
+  need_stride = not all((s is None or s == 1) for s in strides)
   if is_smem_store:
     if val_aval.shape:
       raise ValueError("Can only store scalars to SMEM")
@@ -918,7 +956,10 @@ def _masked_swap_lowering_rule(
   mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
   mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
                                         _dtype_to_ir_type(mem_aval.dtype))
-  result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
+  if need_stride:
+    result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
+  else:
+    result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
   if mem_aval != aval_out:
     # We are slicing a scalar so provided dummy 1 indices
     result_vec_type = ir.VectorType.get(aval_out.shape,
@@ -927,7 +968,10 @@ def _masked_swap_lowering_rule(
     val_vec_type = ir.VectorType.get(mem_aval.shape,
                                      _dtype_to_ir_type(mem_aval.dtype))
     val = vector.ShapeCastOp(val_vec_type, val).result
-  vector.StoreOp(val, ref, starts)
+  if need_stride:
+    tpu.StridedStoreOp(val, ref, starts, strides)
+  else:
+    vector.StoreOp(val, ref, starts)
   return result


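Note: the swap semantics these last two hunks implement can be mirrored in plain NumPy: read the (possibly strided) window, overwrite it with the new value, and return the old contents. Everything below is an illustrative sketch, not JAX or Mosaic API:

# Plain-NumPy mirror of the read-then-write swap performed by the lowering.
import numpy as np

def strided_swap(buf: np.ndarray, starts, sizes, strides, new_vals: np.ndarray):
  window = tuple(slice(st, st + sz * sd, sd)
                 for st, sz, sd in zip(starts, sizes, strides))
  old = buf[window].copy()  # analogue of the (strided) load of the previous value
  buf[window] = new_vals    # analogue of the (strided) store of the new value
  return old                # the swap returns what was there before

buf = np.arange(16).reshape(4, 4)
old = strided_swap(buf, starts=(0, 0), sizes=(2, 2), strides=(2, 2),
                   new_vals=np.zeros((2, 2), dtype=buf.dtype))
assert old.tolist() == [[0, 2], [8, 10]]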