Skip to content

Commit 29d41d8

Browse files
committed
use shape_cast as canonical type for extract broadcast and transpose
1 parent b3ed428 commit 29d41d8

11 files changed

+378
-437
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 129 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,11 +2347,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23472347
return success();
23482348
}
23492349

2350+
/// Rewrite an element-count-preserving vector.extract as a shape_cast, e.g.
2351+
/// ```
2352+
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2353+
/// ```
2354+
/// becomes
2355+
/// ```
2356+
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2357+
/// ```
2358+
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2359+
using OpRewritePattern::OpRewritePattern;
2360+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2361+
PatternRewriter &rewriter) const override {
2362+
VectorType sourceType = extractOp.getSourceVectorType();
2363+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2364+
if (!outType)
2365+
return failure();
2366+
2367+
// Every index in `position` must be the constant 0. Dynamic, non-zero, or
2368+
// negative (poison) indices cannot be converted to shape_cast.
2369+
if (llvm::any_of(extractOp.getMixedPosition(),
2370+
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2371+
return failure();
2372+
2373+
if (sourceType.getNumElements() != outType.getNumElements())
2374+
return failure();
2375+
2376+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2377+
extractOp.getVector());
2378+
return success();
2379+
}
2380+
};
2381+
23502382
} // namespace
23512383

23522384
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23532385
MLIRContext *context) {
2354-
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2386+
results
2387+
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2388+
context);
23552389
results.add(foldExtractFromShapeCastToShapeCast);
23562390
results.add(foldExtractFromFromElements);
23572391
}
@@ -2774,13 +2808,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
27742808
return success();
27752809
}
27762810
};
2811+
2812+
/// Rewrite an element-count-preserving vector.broadcast as a shape_cast, e.g.
2813+
/// ```
2814+
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2815+
/// ```
2816+
/// becomes
2817+
/// ```
2818+
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2819+
/// ```
2820+
struct BroadcastToShapeCast final
2821+
: public OpRewritePattern<vector::BroadcastOp> {
2822+
using OpRewritePattern::OpRewritePattern;
2823+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
2824+
PatternRewriter &rewriter) const override {
2825+
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
2826+
if (!sourceType) {
2827+
return rewriter.notifyMatchFailure(
2828+
broadcast, "source is a scalar, shape_cast doesn't support scalar");
2829+
}
2830+
2831+
VectorType outType = broadcast.getType();
2832+
if (sourceType.getNumElements() != outType.getNumElements())
2833+
return failure();
2834+
2835+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
2836+
broadcast.getSource());
2837+
return success();
2838+
}
2839+
};
27772840
} // namespace
27782841

27792842
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
27802843
MLIRContext *context) {
2781-
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2782-
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2783-
results.add<BroadcastFolder>(context);
2844+
results.add<BroadcastFolder, BroadcastToShapeCast>(context);
27842845
}
27852846

27862847
//===----------------------------------------------------------------------===//
@@ -5698,30 +5759,6 @@ LogicalResult ShapeCastOp::verify() {
56985759
return success();
56995760
}
57005761

5701-
/// Return true if `transpose` does not permute a pair of non-unit dims.
5702-
/// By `order preserving` we mean that the flattened versions of the input and
5703-
/// output vectors are (numerically) identical. In other words `transpose` is
5704-
/// effectively a shape cast.
5705-
static bool isOrderPreserving(TransposeOp transpose) {
5706-
ArrayRef<int64_t> permutation = transpose.getPermutation();
5707-
VectorType sourceType = transpose.getSourceVectorType();
5708-
ArrayRef<int64_t> inShape = sourceType.getShape();
5709-
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5710-
auto isNonScalableUnitDim = [&](int64_t dim) {
5711-
return inShape[dim] == 1 && !inDimIsScalable[dim];
5712-
};
5713-
int64_t current = 0;
5714-
for (auto p : permutation) {
5715-
if (!isNonScalableUnitDim(p)) {
5716-
if (p < current) {
5717-
return false;
5718-
}
5719-
current = p;
5720-
}
5721-
}
5722-
return true;
5723-
}
5724-
57255762
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
57265763

57275764
VectorType resultType = getType();
@@ -5736,33 +5773,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
57365773
return getResult();
57375774
}
57385775

5739-
// shape_cast(transpose(x)) -> shape_cast(x)
5740-
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5741-
// This folder does
5742-
// shape_cast(transpose) -> shape_cast
5743-
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5744-
// shape_cast -> shape_cast(transpose)
5745-
// i.e. the complete opposite. When paired, these 2 patterns can cause
5746-
// infinite cycles in pattern rewriting.
5747-
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5748-
// vectors, so by disabling this folder for scalable vectors the
5749-
// cycle is avoided.
5750-
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5751-
// still needed. If it's not, then we can fold here.
5752-
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5753-
setOperand(transpose.getVector());
5754-
return getResult();
5755-
}
5756-
return {};
5757-
}
5758-
5759-
// Y = shape_cast(broadcast(X))
5760-
// -> X, if X and Y have same type
5761-
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5762-
if (bcastOp.getSourceType() == resultType)
5763-
return bcastOp.getSource();
5764-
}
5765-
57665776
// shape_cast(constant) -> constant
57675777
if (auto splatAttr =
57685778
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5884,10 +5894,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
58845894
}
58855895
};
58865896

5887-
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5888-
/// i) Y = ShapeCast(X), or
5889-
/// ii) Y = Broadcast(X)
5890-
/// If both (i) and (ii) are possible, (i) is chosen.
5897+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
58915898
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58925899
public:
58935900
using OpRewritePattern::OpRewritePattern;
@@ -5902,22 +5909,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
59025909
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
59035910
bool srcIsScalar = !srcVectorType;
59045911

5905-
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5906-
// Example:
5907-
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5908-
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5909-
// to
5910-
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5911-
if (srcVectorType) {
5912-
if (srcVectorType.getNumElements() ==
5913-
shapeCastOp.getResultVectorType().getNumElements()) {
5914-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5915-
shapeCastOp, shapeCastOp.getResultVectorType(),
5916-
broadcastOp.getSource());
5917-
return success();
5918-
}
5919-
}
5920-
59215912
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
59225913
// Example
59235914
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6118,21 +6109,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
61186109
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
61196110
return ub::PoisonAttr::get(getContext());
61206111

6121-
// Eliminate identity transposes, and more generally any transposes that
6122-
// preserves the shape without permuting elements.
6123-
//
6124-
// Examples of what to fold:
6125-
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6126-
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6127-
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6128-
//
6129-
// Example of what NOT to fold:
6130-
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6131-
//
6132-
if (getSourceVectorType() == getResultVectorType() &&
6133-
isOrderPreserving(*this))
6134-
return getVector();
6135-
61366112
return {};
61376113
}
61386114

@@ -6252,32 +6228,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
62526228
}
62536229
};
62546230

6255-
/// Folds transpose(shape_cast) into a new shape_cast.
6256-
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6257-
public:
6258-
using OpRewritePattern::OpRewritePattern;
6259-
6260-
LogicalResult matchAndRewrite(TransposeOp transposeOp,
6261-
PatternRewriter &rewriter) const override {
6262-
auto shapeCastOp =
6263-
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6264-
if (!shapeCastOp)
6265-
return failure();
6266-
if (!isOrderPreserving(transposeOp))
6267-
return failure();
6268-
6269-
VectorType resultType = transposeOp.getType();
6270-
6271-
// We don't need to check isValidShapeCast at this point, because it is
6272-
// guaranteed that merging the transpose into the shape_cast is a valid
6273-
// shape_cast, because the transpose just inserts/removes ones.
6274-
6275-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6276-
shapeCastOp.getSource());
6277-
return success();
6278-
}
6279-
};
6280-
62816231
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
62826232
/// 'order preserving', where 'order preserving' means the flattened
62836233
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6373,12 +6323,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
63736323
}
63746324
};
63756325

6326+
/// Return true if `transpose` does not permute a pair of non-unit dims.
6327+
/// By `order preserving` we mean that the flattened versions of the input and
6328+
/// output vectors are (numerically) identical. In other words `transpose` is
6329+
/// effectively a shape cast.
6330+
static bool isOrderPreserving(TransposeOp transpose) {
6331+
ArrayRef<int64_t> permutation = transpose.getPermutation();
6332+
VectorType sourceType = transpose.getSourceVectorType();
6333+
ArrayRef<int64_t> inShape = sourceType.getShape();
6334+
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6335+
auto isNonScalableUnitDim = [&](int64_t dim) {
6336+
return inShape[dim] == 1 && !inDimIsScalable[dim];
6337+
};
6338+
int64_t current = 0;
6339+
for (auto p : permutation) {
6340+
if (!isNonScalableUnitDim(p)) {
6341+
if (p < current) {
6342+
return false;
6343+
}
6344+
current = p;
6345+
}
6346+
}
6347+
return true;
6348+
}
6349+
6350+
/// Rewrite an order-preserving, non-scalable transpose as a shape_cast, e.g.
6351+
/// ```
6352+
/// %0 = vector.transpose %arg0, [0, 2, 1] :
6353+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6354+
/// ```
6355+
/// becomes
6356+
/// ```
6357+
/// %0 = vector.shape_cast %arg0 :
6358+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6359+
/// ```
6360+
struct TransposeToShapeCast final
6361+
: public OpRewritePattern<vector::TransposeOp> {
6362+
using OpRewritePattern::OpRewritePattern;
6363+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6364+
PatternRewriter &rewriter) const override {
6365+
6366+
// This pattern does
6367+
// transpose -> shape_cast
6368+
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6369+
// shape_cast -> shape_cast(transpose)
6370+
// i.e. the complete opposite. When paired, these 2 patterns can cause
6371+
// infinite cycles in pattern rewriting.
6372+
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6373+
// vectors, so by disabling this folder for scalable vectors the
6374+
// cycle is avoided.
6375+
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6376+
// still needed. If it's not, then we can fold here.
6377+
if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
6378+
return rewriter.notifyMatchFailure(
6379+
transpose, "not order preserving, so not semantically a 'copy'");
6380+
}
6381+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6382+
transpose, transpose.getType(), transpose.getVector());
6383+
return success();
6384+
}
6385+
};
6386+
63766387
} // namespace
63776388

63786389
void vector::TransposeOp::getCanonicalizationPatterns(
63796390
RewritePatternSet &results, MLIRContext *context) {
6380-
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6381-
FoldTransposeSplat, FoldTransposeBroadcast>(context);
6391+
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6392+
FoldTransposeBroadcast, TransposeToShapeCast>(context);
63826393
}
63836394

63846395
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -382,63 +382,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
382382
vector::VectorTransposeLowering vectorTransposeLowering;
383383
};
384384

385-
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
386-
/// to 2D vectors with at least one unit dim. For example:
387-
///
388-
/// Replace:
389-
/// vector.transpose %0, [1, 0] : vector<4x1xi32> to
390-
/// vector<1x4xi32>
391-
/// with:
392-
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
393-
///
394-
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
395-
/// be fixed. Non-unit dim can be scalable.
396-
///
397-
/// TODO: This pattern was introduced specifically to help lower scalable
398-
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
399-
/// to cancel out) would be preferable:
400-
///
401-
/// BEFORE:
402-
/// %0 = some_op
403-
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
404-
/// %2 = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
405-
/// AFTER:
406-
/// %0 = some_op
407-
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
408-
///
409-
/// Given the context above, we may want to consider (re-)moving this pattern
410-
/// at some later time. I am leaving it for now in case there are other users
411-
/// that I am not aware of.
412-
class Transpose2DWithUnitDimToShapeCast
413-
: public OpRewritePattern<vector::TransposeOp> {
414-
public:
415-
using OpRewritePattern::OpRewritePattern;
416-
417-
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
418-
PatternBenefit benefit = 1)
419-
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
420-
421-
LogicalResult matchAndRewrite(vector::TransposeOp op,
422-
PatternRewriter &rewriter) const override {
423-
Value input = op.getVector();
424-
VectorType resType = op.getResultVectorType();
425-
426-
// Set up convenience transposition table.
427-
ArrayRef<int64_t> transp = op.getPermutation();
428-
429-
if (resType.getRank() == 2 &&
430-
((resType.getShape().front() == 1 &&
431-
!resType.getScalableDims().front()) ||
432-
(resType.getShape().back() == 1 &&
433-
!resType.getScalableDims().back())) &&
434-
transp == ArrayRef<int64_t>({1, 0})) {
435-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
436-
return success();
437-
}
438-
439-
return failure();
440-
}
441-
};
442385

443386
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
444387
/// If the strategy is Shuffle1D, it will be lowered to:
@@ -511,8 +454,7 @@ class TransposeOp2DToShuffleLowering
511454
void mlir::vector::populateVectorTransposeLoweringPatterns(
512455
RewritePatternSet &patterns,
513456
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
514-
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
515-
benefit);
457+
BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
516458
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
517459
vectorTransposeLowering, patterns.getContext(), benefit);
518460
}

0 commit comments

Comments
 (0)