
[mlir][vector] Improve shape_cast lowering #140800

Merged (7 commits) on Jun 5, 2025.
Changes from 2 commits
80 changes: 48 additions & 32 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
Contributor:

Perhaps just a slow day for me, but it took me a while to follow the logic in this method. Let me share my observations:

  • step only really defines the step for the trailing dim? If yes, it would be good to update the variable name.
  • spill is either 0 or 1.

Is this correct?

Btw, extra documentation would help. My initial interpretation was: "Update every single index by 1", but that's not true, is it?

Contributor Author:

I've refactored and documented this significantly in the latest commit; it is hopefully clearer now.

                    int step = 1) {
   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
+    int64_t dimSize = shape[dim];
+    assert(indices[dim] < dimSize && "Indices are out of bound");
+
     indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
+
+    int64_t spill = indices[dim] / dimSize;
+    if (spill == 0)
       break;
-    indices[dim] = 0;
-    step = 1;
+    indices[dim] %= dimSize;
+    step = spill;
   }
 }
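[Editor's note] To make the new carry semantics concrete, and to address the question above: spill is not limited to 0 or 1. When step exceeds the trailing dimension size, the whole multi-unit overflow is carried to the next-outer dimension. Below is a minimal standalone sketch of the same logic, not the MLIR source: MutableArrayRef/ArrayRef are replaced with std::vector so it compiles without LLVM headers.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Same carry logic as the new incIdx: advance the innermost index by `step`,
// then propagate any overflow (`spill`) to the next-outer dimension.
static void incIdx(std::vector<int64_t> &indices,
                   const std::vector<int64_t> &shape, int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    int64_t dimSize = shape[dim];
    assert(indices[dim] < dimSize && "Indices are out of bound");

    indices[dim] += step;

    // Number of times this dimension wrapped; it becomes the carry added to
    // the next-outer dimension.
    int64_t spill = indices[dim] / dimSize;
    if (spill == 0)
      break;

    indices[dim] %= dimSize;
    step = spill;
  }
}

int main() {
  // Row-major walk of a 2x3 shape: prints 0,1 0,2 1,0 1,1 1,2
  std::vector<int64_t> idx = {0, 0};
  for (int i = 0; i < 5; ++i) {
    incIdx(idx, {2, 3});
    std::cout << idx[0] << "," << idx[1] << " ";
  }
  // With step=5 on a {4,2} shape, the trailing dim wraps twice (spill == 2):
  idx = {0, 0};
  incIdx(idx, {4, 2}, /*step=*/5);
  std::cout << "\n" << idx[0] << "," << idx[1] << "\n"; // prints 2,1
}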

@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
// and destination slice insertion and generate such instructions.
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
-      incIdx(srcIdx, sourceVectorType, /*step=*/1);
-      incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
}

Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
-      incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
-      incIdx(resIdx, resultVectorType, /*step=*/1);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
}

Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();

-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+    if (sourceType.isScalable() || resultType.isScalable())
return failure();
Contributor:

Since you are refactoring a fair bit, would you mind replacing this (and other instances of failure) with notifyMatchFailure? Thanks!
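[Editor's note] For reference, the suggested change would look something like the fragment below (a sketch only; the diagnostic text is made up):

    // Hypothetical rewrite of the early exit using notifyMatchFailure, so
    // that debug output explains why the pattern did not apply.
    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "shape_cast of scalable vectors not handled by this pattern");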


-    // Special case for n-D / 1-D lowerings with better implementations.
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+    // Special case for n-D / 1-D lowerings with implementations that use
+    // extract_strided_slice / insert_strided_slice.
Contributor:

Could we clarify this comment? (The original part is quite confusing.) Right now, combined with the code, it reads a bit like:

This is a special case, let's fail!

😅 I assume that it was meant to be:

This special case is handled by other, more optimal patterns.

Or something similar :)

Contributor Author:

I didn't much like the logic being spread over the 3 patterns (N->N, 1->N, N->1), as there isn't really anything special about the 1->N and N->1 cases. So I've done a fairly major update to the N->N pattern so that it now handles the 1->N and N->1 cases. As the new change is quite significant, if you'd prefer it to be done in a separate PR, I'm happy to postpone this 'unification', backtrack, and just make the minor changes you suggested to this PR.

I also unified the tests across the test file. The behavior for the 1->N and N->1 cases is unchanged by this PR, though.

+    int64_t sourceRank = sourceType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if ((sourceRank > 1 && resultRank == 1) ||
+        (sourceRank == 1 && resultRank > 1))
return failure();

// Generic ShapeCast lowering path goes all the way down to unrolled scalar
// extract/insert chains.
-    int64_t numElts = 1;
-    for (int64_t r = 0; r < srcRank; r++)
-      numElts *= sourceVectorType.getDimSize(r);
+    int64_t numExtracts = sourceType.getNumElements();
+    int64_t nbCommonInnerDims = 0;
Contributor:

Do both num and nb stand for number? Could you unify?

+    while (true) {
+      int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+      int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+      if (sourceDim < 0 || resultDim < 0)
+        break;
+      int64_t dimSize = sourceType.getDimSize(sourceDim);
+      if (dimSize != resultType.getDimSize(resultDim))
+        break;
+      numExtracts /= dimSize;
+      ++nbCommonInnerDims;
+    }

// Replace with data movement operations:
// x[0,0,0] = y[0,0]
// x[0,0,1] = y[0,1]
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; i++) {
+    SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+    SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+    for (int64_t i = 0; i < numExtracts; i++) {
if (i != 0) {
-        incIdx(srcIdx, sourceVectorType);
-        incIdx(resIdx, resultVectorType);
+        incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+        incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
}

Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
}
rewriter.replaceOp(op, result);
return success();
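[Editor's note] As a sanity check on the common-inner-dims loop above, here is a small self-contained sketch (plain std::vector standing in for VectorType and its getNumElements/getDimSize accessors, which is an assumption of this sketch) that reproduces the nbCommonInnerDims / numExtracts computation for the @squeeze_out_middle_unit_dim test added later in this PR:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // shape_cast vector<2x1x3xf32> -> vector<2x3xf32>: the shapes share one
  // trailing dim (3), so each extract is a rank-1 vector<3xf32>.
  std::vector<int64_t> source = {2, 1, 3}, result = {2, 3};

  int64_t numExtracts = 1;
  for (int64_t d : source)
    numExtracts *= d; // getNumElements(): 6 scalars in total

  int64_t nbCommonInnerDims = 0;
  while (true) {
    int64_t sourceDim = (int64_t)source.size() - 1 - nbCommonInnerDims;
    int64_t resultDim = (int64_t)result.size() - 1 - nbCommonInnerDims;
    if (sourceDim < 0 || resultDim < 0)
      break;
    int64_t dimSize = source[sourceDim];
    if (dimSize != result[resultDim])
      break;
    numExtracts /= dimSize; // each extract moves a whole common-inner slice
    ++nbCommonInnerDims;
  }
  std::cout << nbCommonInnerDims << " " << numExtracts << "\n"; // prints: 1 2
}

The result matches the two extract/insert pairs checked in @squeeze_out_middle_unit_dim below.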
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern

// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
-      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
}

rewriter.replaceOp(op, result);
@@ -140,6 +140,59 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
return %s : vector<f32>
}


// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
Contributor:

This indentation is inconsistent with what's used in @squeeze_out_middle_unit_dim.

Contributor Author:

I did a significant test file refactor to make the pre-existing and new tests consistent. The pre-existing test logic is unchanged. I'm happy to postpone refactoring the old tests to make that clearer.

Contributor:

I'm happy for you to keep all these changes in this PR - this is effectively a proper refactor of the pattern, which was in dire need of some TLC anyway 🙂

// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}

// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
// CHECK-SAME: into vector<2x3xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
// CHECK-SAME: into vector<2x3xf32>
// CHECK: return %[[I1]] : vector<2x3xf32>
func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}

// CHECK-LABEL: func.func @prepend_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
Contributor:

This indentation is off.

func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
return %s : vector<1x2x3xf32>
}

// CHECK-LABEL: func.func @insert_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
// CHECK: return %[[I1]] : vector<2x1x3xf32>
func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op