@@ -2351,11 +2351,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2351
2351
return success ();
2352
2352
}
2353
2353
2354
+ // / For example,
2355
+ // / ```
2356
+ // / %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2357
+ // / ```
2358
+ // / becomes
2359
+ // / ```
2360
+ // / %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2361
+ // / ```
2362
+ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2363
+ using OpRewritePattern::OpRewritePattern;
2364
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
2365
+ PatternRewriter &rewriter) const override {
2366
+ VectorType sourceType = extractOp.getSourceVectorType ();
2367
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2368
+ if (!outType)
2369
+ return failure ();
2370
+
2371
+ // Negative values in `position` indicates poison, cannot convert to
2372
+ // shape_cast
2373
+ if (llvm::any_of (extractOp.getMixedPosition (),
2374
+ [](OpFoldResult v) { return !isConstantIntValue (v, 0 ); }))
2375
+ return failure ();
2376
+
2377
+ if (sourceType.getNumElements () != outType.getNumElements ())
2378
+ return failure ();
2379
+
2380
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(extractOp, outType,
2381
+ extractOp.getVector ());
2382
+ return success ();
2383
+ }
2384
+ };
2385
+
2354
2386
} // namespace
2355
2387
2356
2388
void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
2357
2389
MLIRContext *context) {
2358
- results.add <ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2390
+ results
2391
+ .add <ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2392
+ context);
2359
2393
results.add (foldExtractFromShapeCastToShapeCast);
2360
2394
results.add (foldExtractFromFromElements);
2361
2395
}
@@ -2867,13 +2901,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2867
2901
return success ();
2868
2902
}
2869
2903
};
2904
+
2905
+ // / For example,
2906
+ // / ```
2907
+ // / %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2908
+ // / ```
2909
+ // / becomes
2910
+ // / ```
2911
+ // / %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2912
+ // / ```
2913
+ struct BroadcastToShapeCast final
2914
+ : public OpRewritePattern<vector::BroadcastOp> {
2915
+ using OpRewritePattern::OpRewritePattern;
2916
+ LogicalResult matchAndRewrite (vector::BroadcastOp broadcast,
2917
+ PatternRewriter &rewriter) const override {
2918
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType ());
2919
+ if (!sourceType) {
2920
+ return rewriter.notifyMatchFailure (
2921
+ broadcast, " source is a scalar, shape_cast doesn't support scalar" );
2922
+ }
2923
+
2924
+ VectorType outType = broadcast.getType ();
2925
+ if (sourceType.getNumElements () != outType.getNumElements ())
2926
+ return failure ();
2927
+
2928
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(broadcast, outType,
2929
+ broadcast.getSource ());
2930
+ return success ();
2931
+ }
2932
+ };
2870
2933
} // namespace
2871
2934
2872
2935
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
2873
2936
MLIRContext *context) {
2874
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2875
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
2876
- results.add <BroadcastFolder>(context);
2937
+ results.add <BroadcastFolder, BroadcastToShapeCast>(context);
2877
2938
}
2878
2939
2879
2940
// ===----------------------------------------------------------------------===//
@@ -5816,30 +5877,6 @@ LogicalResult ShapeCastOp::verify() {
5816
5877
return success ();
5817
5878
}
5818
5879
5819
- // / Return true if `transpose` does not permute a pair of non-unit dims.
5820
- // / By `order preserving` we mean that the flattened versions of the input and
5821
- // / output vectors are (numerically) identical. In other words `transpose` is
5822
- // / effectively a shape cast.
5823
- static bool isOrderPreserving (TransposeOp transpose) {
5824
- ArrayRef<int64_t > permutation = transpose.getPermutation ();
5825
- VectorType sourceType = transpose.getSourceVectorType ();
5826
- ArrayRef<int64_t > inShape = sourceType.getShape ();
5827
- ArrayRef<bool > inDimIsScalable = sourceType.getScalableDims ();
5828
- auto isNonScalableUnitDim = [&](int64_t dim) {
5829
- return inShape[dim] == 1 && !inDimIsScalable[dim];
5830
- };
5831
- int64_t current = 0 ;
5832
- for (auto p : permutation) {
5833
- if (!isNonScalableUnitDim (p)) {
5834
- if (p < current) {
5835
- return false ;
5836
- }
5837
- current = p;
5838
- }
5839
- }
5840
- return true ;
5841
- }
5842
-
5843
5880
OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5844
5881
5845
5882
VectorType resultType = getType ();
@@ -5854,22 +5891,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5854
5891
return getResult ();
5855
5892
}
5856
5893
5857
- // shape_cast(transpose(x)) -> shape_cast(x)
5858
- if (auto transpose = getSource ().getDefiningOp <TransposeOp>()) {
5859
- if (isOrderPreserving (transpose)) {
5860
- setOperand (transpose.getVector ());
5861
- return getResult ();
5862
- }
5863
- return {};
5864
- }
5865
-
5866
- // Y = shape_cast(broadcast(X))
5867
- // -> X, if X and Y have same type
5868
- if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5869
- if (bcastOp.getSourceType () == resultType)
5870
- return bcastOp.getSource ();
5871
- }
5872
-
5873
5894
// shape_cast(constant) -> constant
5874
5895
if (auto splatAttr =
5875
5896
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ()))
@@ -5991,10 +6012,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
5991
6012
}
5992
6013
};
5993
6014
5994
- // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5995
- // / i) Y = ShapeCast(X), or
5996
- // / ii) Y = Broadcast(X)
5997
- // / If both (i) and (ii) are possible, (i) is chosen.
6015
+ // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
5998
6016
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5999
6017
public:
6000
6018
using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6027,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6009
6027
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
6010
6028
bool srcIsScalar = !srcVectorType;
6011
6029
6012
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6013
- // Example:
6014
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6015
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6016
- // to
6017
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6018
- if (srcVectorType) {
6019
- if (srcVectorType.getNumElements () ==
6020
- shapeCastOp.getResultVectorType ().getNumElements ()) {
6021
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
6022
- shapeCastOp, shapeCastOp.getResultVectorType (),
6023
- broadcastOp.getSource ());
6024
- return success ();
6025
- }
6026
- }
6027
-
6028
6030
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6029
6031
// Example
6030
6032
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6225,21 +6227,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6225
6227
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector ()))
6226
6228
return ub::PoisonAttr::get (getContext ());
6227
6229
6228
- // Eliminate identity transposes, and more generally any transposes that
6229
- // preserves the shape without permuting elements.
6230
- //
6231
- // Examples of what to fold:
6232
- // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6233
- // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6234
- // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6235
- //
6236
- // Example of what NOT to fold:
6237
- // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6238
- //
6239
- if (getSourceVectorType () == getResultVectorType () &&
6240
- isOrderPreserving (*this ))
6241
- return getVector ();
6242
-
6243
6230
return {};
6244
6231
}
6245
6232
@@ -6359,32 +6346,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6359
6346
}
6360
6347
};
6361
6348
6362
- // / Folds transpose(shape_cast) into a new shape_cast.
6363
- class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6364
- public:
6365
- using OpRewritePattern::OpRewritePattern;
6366
-
6367
- LogicalResult matchAndRewrite (TransposeOp transposeOp,
6368
- PatternRewriter &rewriter) const override {
6369
- auto shapeCastOp =
6370
- transposeOp.getVector ().getDefiningOp <vector::ShapeCastOp>();
6371
- if (!shapeCastOp)
6372
- return failure ();
6373
- if (!isOrderPreserving (transposeOp))
6374
- return failure ();
6375
-
6376
- VectorType resultType = transposeOp.getType ();
6377
-
6378
- // We don't need to check isValidShapeCast at this point, because it is
6379
- // guaranteed that merging the transpose into the the shape_cast is a valid
6380
- // shape_cast, because the transpose just inserts/removes ones.
6381
-
6382
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(transposeOp, resultType,
6383
- shapeCastOp.getSource ());
6384
- return success ();
6385
- }
6386
- };
6387
-
6388
6349
// / Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6389
6350
// / 'order preserving', where 'order preserving' means the flattened
6390
6351
// / inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6441,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6480
6441
}
6481
6442
};
6482
6443
6444
+ // / Return true if `transpose` does not permute a pair of non-unit dims.
6445
+ // / By `order preserving` we mean that the flattened versions of the input and
6446
+ // / output vectors are (numerically) identical. In other words `transpose` is
6447
+ // / effectively a shape cast.
6448
+ static bool isOrderPreserving (TransposeOp transpose) {
6449
+ ArrayRef<int64_t > permutation = transpose.getPermutation ();
6450
+ VectorType sourceType = transpose.getSourceVectorType ();
6451
+ ArrayRef<int64_t > inShape = sourceType.getShape ();
6452
+ ArrayRef<bool > inDimIsScalable = sourceType.getScalableDims ();
6453
+ auto isNonScalableUnitDim = [&](int64_t dim) {
6454
+ return inShape[dim] == 1 && !inDimIsScalable[dim];
6455
+ };
6456
+ int64_t current = 0 ;
6457
+ for (auto p : permutation) {
6458
+ if (!isNonScalableUnitDim (p)) {
6459
+ if (p < current) {
6460
+ return false ;
6461
+ }
6462
+ current = p;
6463
+ }
6464
+ }
6465
+ return true ;
6466
+ }
6467
+
6468
+ // / For example,
6469
+ // / ```
6470
+ // / %0 = vector.transpose %arg0, [0, 2, 1] :
6471
+ // / vector<2x1x2xf32> to vector<2x2x1xf32>
6472
+ // / ```
6473
+ // / becomes
6474
+ // / ```
6475
+ // / %0 = vector.shape_cast %arg0 :
6476
+ // / vector<2x1x2xf32> to vector<2x2x1xf32>
6477
+ // / ```
6478
+ struct TransposeToShapeCast final
6479
+ : public OpRewritePattern<vector::TransposeOp> {
6480
+ using OpRewritePattern::OpRewritePattern;
6481
+ LogicalResult matchAndRewrite (vector::TransposeOp transpose,
6482
+ PatternRewriter &rewriter) const override {
6483
+
6484
+ // This folder does
6485
+ // shape_cast(transpose) -> shape_cast
6486
+ // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6487
+ // shape_cast -> shape_cast(transpose)
6488
+ // i.e. the complete opposite. When paired, these 2 patterns can cause
6489
+ // infinite cycles in pattern rewriting.
6490
+ // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6491
+ // vectors, so by disabling this folder for scalable vectors the
6492
+ // cycle is avoided.
6493
+ // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6494
+ // still needed. If it's not, then we can fold here.
6495
+ if (!isOrderPreserving (transpose) || transpose.getType ().isScalable ()) {
6496
+ return rewriter.notifyMatchFailure (
6497
+ transpose, " not order preserving, so not semantically a 'copy'" );
6498
+ }
6499
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
6500
+ transpose, transpose.getType (), transpose.getVector ());
6501
+ return success ();
6502
+ }
6503
+ };
6504
+
6483
6505
} // namespace
6484
6506
6485
6507
void vector::TransposeOp::getCanonicalizationPatterns (
6486
6508
RewritePatternSet &results, MLIRContext *context) {
6487
- results.add <FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder ,
6488
- FoldTransposeSplat, FoldTransposeBroadcast >(context);
6509
+ results.add <FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat ,
6510
+ FoldTransposeBroadcast, TransposeToShapeCast >(context);
6489
6511
}
6490
6512
6491
6513
// ===----------------------------------------------------------------------===//
0 commit comments