Skip to content

Commit cce1483

Browse files
committed
use shape_cast as canonical type for extract broadcast and transpose
1 parent 00f6d6a commit cce1483

11 files changed

+379
-316
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 129 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,11 +2351,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23512351
return success();
23522352
}
23532353

2354+
/// For example,
2355+
/// ```
2356+
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2357+
/// ```
2358+
/// becomes
2359+
/// ```
2360+
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2361+
/// ```
2362+
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2363+
using OpRewritePattern::OpRewritePattern;
2364+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2365+
PatternRewriter &rewriter) const override {
2366+
VectorType sourceType = extractOp.getSourceVectorType();
2367+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2368+
if (!outType)
2369+
return failure();
2370+
2371+
// Every index in `position` must be the constant 0: dynamic, non-zero and
2372+
// negative (poison) indices cannot be converted to shape_cast
2373+
if (llvm::any_of(extractOp.getMixedPosition(),
2374+
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2375+
return failure();
2376+
2377+
if (sourceType.getNumElements() != outType.getNumElements())
2378+
return failure();
2379+
2380+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2381+
extractOp.getVector());
2382+
return success();
2383+
}
2384+
};
2385+
23542386
} // namespace
23552387

23562388
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23572389
MLIRContext *context) {
2358-
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2390+
results
2391+
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2392+
context);
23592393
results.add(foldExtractFromShapeCastToShapeCast);
23602394
results.add(foldExtractFromFromElements);
23612395
}
@@ -2867,13 +2901,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
28672901
return success();
28682902
}
28692903
};
2904+
2905+
/// For example,
2906+
/// ```
2907+
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2908+
/// ```
2909+
/// becomes
2910+
/// ```
2911+
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2912+
/// ```
2913+
struct BroadcastToShapeCast final
2914+
: public OpRewritePattern<vector::BroadcastOp> {
2915+
using OpRewritePattern::OpRewritePattern;
2916+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
2917+
PatternRewriter &rewriter) const override {
2918+
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
2919+
if (!sourceType) {
2920+
return rewriter.notifyMatchFailure(
2921+
broadcast, "source is a scalar, shape_cast doesn't support scalar");
2922+
}
2923+
2924+
VectorType outType = broadcast.getType();
2925+
if (sourceType.getNumElements() != outType.getNumElements())
2926+
return failure();
2927+
2928+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
2929+
broadcast.getSource());
2930+
return success();
2931+
}
2932+
};
28702933
} // namespace
28712934

28722935
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
28732936
MLIRContext *context) {
2874-
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2875-
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2876-
results.add<BroadcastFolder>(context);
2937+
results.add<BroadcastFolder, BroadcastToShapeCast>(context);
28772938
}
28782939

28792940
//===----------------------------------------------------------------------===//
@@ -5816,30 +5877,6 @@ LogicalResult ShapeCastOp::verify() {
58165877
return success();
58175878
}
58185879

5819-
/// Return true if `transpose` does not permute a pair of non-unit dims.
5820-
/// By `order preserving` we mean that the flattened versions of the input and
5821-
/// output vectors are (numerically) identical. In other words `transpose` is
5822-
/// effectively a shape cast.
5823-
static bool isOrderPreserving(TransposeOp transpose) {
5824-
ArrayRef<int64_t> permutation = transpose.getPermutation();
5825-
VectorType sourceType = transpose.getSourceVectorType();
5826-
ArrayRef<int64_t> inShape = sourceType.getShape();
5827-
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5828-
auto isNonScalableUnitDim = [&](int64_t dim) {
5829-
return inShape[dim] == 1 && !inDimIsScalable[dim];
5830-
};
5831-
int64_t current = 0;
5832-
for (auto p : permutation) {
5833-
if (!isNonScalableUnitDim(p)) {
5834-
if (p < current) {
5835-
return false;
5836-
}
5837-
current = p;
5838-
}
5839-
}
5840-
return true;
5841-
}
5842-
58435880
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
58445881

58455882
VectorType resultType = getType();
@@ -5854,22 +5891,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
58545891
return getResult();
58555892
}
58565893

5857-
// shape_cast(transpose(x)) -> shape_cast(x)
5858-
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5859-
if (isOrderPreserving(transpose)) {
5860-
setOperand(transpose.getVector());
5861-
return getResult();
5862-
}
5863-
return {};
5864-
}
5865-
5866-
// Y = shape_cast(broadcast(X))
5867-
// -> X, if X and Y have same type
5868-
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5869-
if (bcastOp.getSourceType() == resultType)
5870-
return bcastOp.getSource();
5871-
}
5872-
58735894
// shape_cast(constant) -> constant
58745895
if (auto splatAttr =
58755896
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5991,10 +6012,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
59916012
}
59926013
};
59936014

5994-
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5995-
/// i) Y = ShapeCast(X), or
5996-
/// ii) Y = Broadcast(X)
5997-
/// If both (i) and (ii) are possible, (i) is chosen.
6015+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
59986016
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
59996017
public:
60006018
using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6027,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
60096027
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
60106028
bool srcIsScalar = !srcVectorType;
60116029

6012-
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6013-
// Example:
6014-
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6015-
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6016-
// to
6017-
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6018-
if (srcVectorType) {
6019-
if (srcVectorType.getNumElements() ==
6020-
shapeCastOp.getResultVectorType().getNumElements()) {
6021-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6022-
shapeCastOp, shapeCastOp.getResultVectorType(),
6023-
broadcastOp.getSource());
6024-
return success();
6025-
}
6026-
}
6027-
60286030
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
60296031
// Example
60306032
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6225,21 +6227,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
62256227
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
62266228
return ub::PoisonAttr::get(getContext());
62276229

6228-
// Eliminate identity transposes, and more generally any transposes that
6229-
// preserves the shape without permuting elements.
6230-
//
6231-
// Examples of what to fold:
6232-
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6233-
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6234-
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6235-
//
6236-
// Example of what NOT to fold:
6237-
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6238-
//
6239-
if (getSourceVectorType() == getResultVectorType() &&
6240-
isOrderPreserving(*this))
6241-
return getVector();
6242-
62436230
return {};
62446231
}
62456232

@@ -6359,32 +6346,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
63596346
}
63606347
};
63616348

6362-
/// Folds transpose(shape_cast) into a new shape_cast.
6363-
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6364-
public:
6365-
using OpRewritePattern::OpRewritePattern;
6366-
6367-
LogicalResult matchAndRewrite(TransposeOp transposeOp,
6368-
PatternRewriter &rewriter) const override {
6369-
auto shapeCastOp =
6370-
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6371-
if (!shapeCastOp)
6372-
return failure();
6373-
if (!isOrderPreserving(transposeOp))
6374-
return failure();
6375-
6376-
VectorType resultType = transposeOp.getType();
6377-
6378-
// We don't need to check isValidShapeCast at this point, because it is
6379-
// guaranteed that merging the transpose into the the shape_cast is a valid
6380-
// shape_cast, because the transpose just inserts/removes ones.
6381-
6382-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6383-
shapeCastOp.getSource());
6384-
return success();
6385-
}
6386-
};
6387-
63886349
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
63896350
/// 'order preserving', where 'order preserving' means the flattened
63906351
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6441,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
64806441
}
64816442
};
64826443

6444+
/// Return true if `transpose` does not permute a pair of non-unit dims.
6445+
/// By `order preserving` we mean that the flattened versions of the input and
6446+
/// output vectors are (numerically) identical. In other words `transpose` is
6447+
/// effectively a shape cast.
6448+
static bool isOrderPreserving(TransposeOp transpose) {
6449+
ArrayRef<int64_t> permutation = transpose.getPermutation();
6450+
VectorType sourceType = transpose.getSourceVectorType();
6451+
ArrayRef<int64_t> inShape = sourceType.getShape();
6452+
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6453+
auto isNonScalableUnitDim = [&](int64_t dim) {
6454+
return inShape[dim] == 1 && !inDimIsScalable[dim];
6455+
};
6456+
int64_t current = 0;
6457+
for (auto p : permutation) {
6458+
if (!isNonScalableUnitDim(p)) {
6459+
if (p < current) {
6460+
return false;
6461+
}
6462+
current = p;
6463+
}
6464+
}
6465+
return true;
6466+
}
6467+
6468+
/// For example,
6469+
/// ```
6470+
/// %0 = vector.transpose %arg0, [0, 2, 1] :
6471+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6472+
/// ```
6473+
/// becomes
6474+
/// ```
6475+
/// %0 = vector.shape_cast %arg0 :
6476+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6477+
/// ```
6478+
struct TransposeToShapeCast final
6479+
: public OpRewritePattern<vector::TransposeOp> {
6480+
using OpRewritePattern::OpRewritePattern;
6481+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6482+
PatternRewriter &rewriter) const override {
6483+
6484+
// This pattern does
6485+
// transpose -> shape_cast
6486+
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6487+
// shape_cast -> shape_cast(transpose)
6488+
// i.e. the complete opposite. When paired, these 2 patterns can cause
6489+
// infinite cycles in pattern rewriting.
6490+
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6491+
// vectors, so by disabling this pattern for scalable vectors the
6492+
// cycle is avoided.
6493+
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6494+
// still needed. If it's not, then we can fold here.
6495+
if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
6496+
return rewriter.notifyMatchFailure(
6497+
transpose, "not order preserving, so not semantically a 'copy'");
6498+
}
6499+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6500+
transpose, transpose.getType(), transpose.getVector());
6501+
return success();
6502+
}
6503+
};
6504+
64836505
} // namespace
64846506

64856507
void vector::TransposeOp::getCanonicalizationPatterns(
64866508
RewritePatternSet &results, MLIRContext *context) {
6487-
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6488-
FoldTransposeSplat, FoldTransposeBroadcast>(context);
6509+
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6510+
FoldTransposeBroadcast, TransposeToShapeCast>(context);
64896511
}
64906512

64916513
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -382,63 +382,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
382382
vector::VectorTransposeLowering vectorTransposeLowering;
383383
};
384384

385-
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
386-
/// to 2D vectors with at least one unit dim. For example:
387-
///
388-
/// Replace:
389-
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
390-
/// vector<1x4xi32>
391-
/// with:
392-
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
393-
///
394-
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
395-
/// be fixed. Non-unit dim can be scalable.
396-
///
397-
/// TODO: This pattern was introduced specifically to help lower scalable
398-
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
399-
/// to cancel out) would be preferable:
400-
///
401-
/// BEFORE:
402-
/// %0 = some_op
403-
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
404-
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
405-
/// AFTER:
406-
/// %0 = some_op
407-
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
408-
///
409-
/// Given the context above, we may want to consider (re-)moving this pattern
410-
/// at some later time. I am leaving it for now in case there are other users
411-
/// that I am not aware of.
412-
class Transpose2DWithUnitDimToShapeCast
413-
: public OpRewritePattern<vector::TransposeOp> {
414-
public:
415-
using OpRewritePattern::OpRewritePattern;
416-
417-
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
418-
PatternBenefit benefit = 1)
419-
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
420-
421-
LogicalResult matchAndRewrite(vector::TransposeOp op,
422-
PatternRewriter &rewriter) const override {
423-
Value input = op.getVector();
424-
VectorType resType = op.getResultVectorType();
425-
426-
// Set up convenience transposition table.
427-
ArrayRef<int64_t> transp = op.getPermutation();
428-
429-
if (resType.getRank() == 2 &&
430-
((resType.getShape().front() == 1 &&
431-
!resType.getScalableDims().front()) ||
432-
(resType.getShape().back() == 1 &&
433-
!resType.getScalableDims().back())) &&
434-
transp == ArrayRef<int64_t>({1, 0})) {
435-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
436-
return success();
437-
}
438-
439-
return failure();
440-
}
441-
};
442385

443386
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
444387
/// If the strategy is Shuffle1D, it will be lowered to:
@@ -511,8 +454,7 @@ class TransposeOp2DToShuffleLowering
511454
void mlir::vector::populateVectorTransposeLoweringPatterns(
512455
RewritePatternSet &patterns,
513456
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
514-
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
515-
benefit);
457+
BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
516458
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
517459
vectorTransposeLowering, patterns.getContext(), benefit);
518460
}

0 commit comments

Comments
 (0)