Skip to content

Commit 7d431ad

Browse files
apaszkejax authors
authored andcommitted
Add support for slicing dynamically-shaped memrefs + DMAs between them
This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions. I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e. for as long as we never have memrefs as loop carry or return them from conditionals). TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible. I am afraid that the signature change in MemRefSliceOp might not be. PiperOrigin-RevId: 617755035
1 parent 8920b54 commit 7d431ad

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
833833
memory_space=ref.type.memory_space)
834834
inner_aval = ref_aval.inner_aval
835835
out_aval = ref_aval.update(inner_aval=inner_aval.update(shape=target_shape))
836-
out = tpu.MemRefSliceOp(target_ref_ty, ref, starts).result
836+
out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, []).result
837837
if any(squeeze_dims):
838838
# We need to squeeze out some dimensions
839839
squeezed_ref_ty = ir.MemRefType.get(

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,16 @@ def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResul
397397
let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result);
398398
}
399399

400-
def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure]> {
400+
def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> {
401401
let arguments = (ins
402402
AnyMemRef:$mem_ref,
403-
Variadic<I32>:$base_idx
403+
Variadic<I32>:$base_idx,
404+
Variadic<I32>:$dynamic_sizes
404405
);
405406
let results = (outs AnyMemRef:$result);
406407
let assemblyFormat = [{
407-
$mem_ref `[` $base_idx `]` attr-dict `:` type($mem_ref) `->` type($result)
408+
$mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)?
409+
attr-dict `:` type($mem_ref) `->` type($result)
408410
}];
409411
let hasVerifier = 1;
410412
let hasCanonicalizeMethod = 1;

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ LogicalResult MemRefSliceOp::verify() {
6666
(target_memory_space == nullptr ||
6767
target_memory_space == source_type.getMemorySpace()) &&
6868
((isa<AffineMapAttr>(target_layout) && target_layout.isIdentity()) ||
69-
target_type.getLayout() == source_type.getLayout()));
69+
target_type.getLayout() == source_type.getLayout()) &&
70+
getDynamicSizes().size() == target_type.getNumDynamicDims());
7071
}
7172

7273
LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op,
@@ -82,8 +83,9 @@ LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op,
8283
auto new_result_type = MemRefType::get(
8384
op.getResult().getType().getShape(), layout_ty.getElementType(),
8485
layout_ty.getLayout(), layout_ty.getMemorySpace());
85-
auto slice = rewriter.create<MemRefSliceOp>(op.getLoc(), new_result_type,
86-
layout_ref, op.getBaseIdx());
86+
auto slice =
87+
rewriter.create<MemRefSliceOp>(op.getLoc(), new_result_type, layout_ref,
88+
op.getBaseIdx(), op.getDynamicSizes());
8789
rewriter.replaceOpWithNewOp<EraseLayoutOp>(op, op.getType(), slice);
8890
return success();
8991
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ FailureOr<std::pair<Value, SmallVector<int64_t>>> sliceRef(
367367
Value sliced_ref = builder.create<tpu::MemRefSliceOp>(
368368
MemRefType::get(pad_slice_shape, ref_ty.getElementType(),
369369
ref_ty.getLayout(), ref_ty.getMemorySpace()),
370-
base_ref, slice_base_indices);
370+
base_ref, slice_base_indices, /*dynamic_sizes=*/ValueRange());
371371

372372
return std::make_pair(sliced_ref, indices_within_slice);
373373
}

0 commit comments

Comments
 (0)