@@ -2347,11 +2347,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2347
2347
return success ();
2348
2348
}
2349
2349
2350
+ // / For example,
2351
+ // / ```
2352
+ // / %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2353
+ // / ```
2354
+ // / becomes
2355
+ // / ```
2356
+ // / %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2357
+ // / ```
2358
+ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2359
+ using OpRewritePattern::OpRewritePattern;
2360
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
2361
+ PatternRewriter &rewriter) const override {
2362
+ VectorType sourceType = extractOp.getSourceVectorType ();
2363
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2364
+ if (!outType)
2365
+ return failure ();
2366
+
2367
+ // Negative values in `position` indicates poison, cannot convert to
2368
+ // shape_cast
2369
+ if (llvm::any_of (extractOp.getMixedPosition (),
2370
+ [](OpFoldResult v) { return !isConstantIntValue (v, 0 ); }))
2371
+ return failure ();
2372
+
2373
+ if (sourceType.getNumElements () != outType.getNumElements ())
2374
+ return failure ();
2375
+
2376
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(extractOp, outType,
2377
+ extractOp.getVector ());
2378
+ return success ();
2379
+ }
2380
+ };
2381
+
2350
2382
} // namespace
2351
2383
2352
2384
void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
2353
2385
MLIRContext *context) {
2354
- results.add <ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2386
+ results
2387
+ .add <ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2388
+ context);
2355
2389
results.add (foldExtractFromShapeCastToShapeCast);
2356
2390
results.add (foldExtractFromFromElements);
2357
2391
}
@@ -2774,13 +2808,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2774
2808
return success ();
2775
2809
}
2776
2810
};
2811
+
2812
+ // / For example,
2813
+ // / ```
2814
+ // / %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2815
+ // / ```
2816
+ // / becomes
2817
+ // / ```
2818
+ // / %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2819
+ // / ```
2820
+ struct BroadcastToShapeCast final
2821
+ : public OpRewritePattern<vector::BroadcastOp> {
2822
+ using OpRewritePattern::OpRewritePattern;
2823
+ LogicalResult matchAndRewrite (vector::BroadcastOp broadcast,
2824
+ PatternRewriter &rewriter) const override {
2825
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType ());
2826
+ if (!sourceType) {
2827
+ return rewriter.notifyMatchFailure (
2828
+ broadcast, " source is a scalar, shape_cast doesn't support scalar" );
2829
+ }
2830
+
2831
+ VectorType outType = broadcast.getType ();
2832
+ if (sourceType.getNumElements () != outType.getNumElements ())
2833
+ return failure ();
2834
+
2835
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(broadcast, outType,
2836
+ broadcast.getSource ());
2837
+ return success ();
2838
+ }
2839
+ };
2777
2840
} // namespace
2778
2841
2779
2842
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
2780
2843
MLIRContext *context) {
2781
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2782
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
2783
- results.add <BroadcastFolder>(context);
2844
+ results.add <BroadcastFolder, BroadcastToShapeCast>(context);
2784
2845
}
2785
2846
2786
2847
// ===----------------------------------------------------------------------===//
@@ -5698,30 +5759,6 @@ LogicalResult ShapeCastOp::verify() {
5698
5759
return success ();
5699
5760
}
5700
5761
5701
- // / Return true if `transpose` does not permute a pair of non-unit dims.
5702
- // / By `order preserving` we mean that the flattened versions of the input and
5703
- // / output vectors are (numerically) identical. In other words `transpose` is
5704
- // / effectively a shape cast.
5705
- static bool isOrderPreserving (TransposeOp transpose) {
5706
- ArrayRef<int64_t > permutation = transpose.getPermutation ();
5707
- VectorType sourceType = transpose.getSourceVectorType ();
5708
- ArrayRef<int64_t > inShape = sourceType.getShape ();
5709
- ArrayRef<bool > inDimIsScalable = sourceType.getScalableDims ();
5710
- auto isNonScalableUnitDim = [&](int64_t dim) {
5711
- return inShape[dim] == 1 && !inDimIsScalable[dim];
5712
- };
5713
- int64_t current = 0 ;
5714
- for (auto p : permutation) {
5715
- if (!isNonScalableUnitDim (p)) {
5716
- if (p < current) {
5717
- return false ;
5718
- }
5719
- current = p;
5720
- }
5721
- }
5722
- return true ;
5723
- }
5724
-
5725
5762
OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5726
5763
5727
5764
VectorType resultType = getType ();
@@ -5736,33 +5773,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5736
5773
return getResult ();
5737
5774
}
5738
5775
5739
- // shape_cast(transpose(x)) -> shape_cast(x)
5740
- if (auto transpose = getSource ().getDefiningOp <TransposeOp>()) {
5741
- // This folder does
5742
- // shape_cast(transpose) -> shape_cast
5743
- // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5744
- // shape_cast -> shape_cast(transpose)
5745
- // i.e. the complete opposite. When paired, these 2 patterns can cause
5746
- // infinite cycles in pattern rewriting.
5747
- // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5748
- // vectors, so by disabling this folder for scalable vectors the
5749
- // cycle is avoided.
5750
- // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5751
- // still needed. If it's not, then we can fold here.
5752
- if (!transpose.getType ().isScalable () && isOrderPreserving (transpose)) {
5753
- setOperand (transpose.getVector ());
5754
- return getResult ();
5755
- }
5756
- return {};
5757
- }
5758
-
5759
- // Y = shape_cast(broadcast(X))
5760
- // -> X, if X and Y have same type
5761
- if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5762
- if (bcastOp.getSourceType () == resultType)
5763
- return bcastOp.getSource ();
5764
- }
5765
-
5766
5776
// shape_cast(constant) -> constant
5767
5777
if (auto splatAttr =
5768
5778
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ()))
@@ -5884,10 +5894,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
5884
5894
}
5885
5895
};
5886
5896
5887
- // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5888
- // / i) Y = ShapeCast(X), or
5889
- // / ii) Y = Broadcast(X)
5890
- // / If both (i) and (ii) are possible, (i) is chosen.
5897
+ // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
5891
5898
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5892
5899
public:
5893
5900
using OpRewritePattern::OpRewritePattern;
@@ -5902,22 +5909,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5902
5909
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
5903
5910
bool srcIsScalar = !srcVectorType;
5904
5911
5905
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5906
- // Example:
5907
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5908
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5909
- // to
5910
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5911
- if (srcVectorType) {
5912
- if (srcVectorType.getNumElements () ==
5913
- shapeCastOp.getResultVectorType ().getNumElements ()) {
5914
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
5915
- shapeCastOp, shapeCastOp.getResultVectorType (),
5916
- broadcastOp.getSource ());
5917
- return success ();
5918
- }
5919
- }
5920
-
5921
5912
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
5922
5913
// Example
5923
5914
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6118,21 +6109,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6118
6109
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector ()))
6119
6110
return ub::PoisonAttr::get (getContext ());
6120
6111
6121
- // Eliminate identity transposes, and more generally any transposes that
6122
- // preserves the shape without permuting elements.
6123
- //
6124
- // Examples of what to fold:
6125
- // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6126
- // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6127
- // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6128
- //
6129
- // Example of what NOT to fold:
6130
- // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6131
- //
6132
- if (getSourceVectorType () == getResultVectorType () &&
6133
- isOrderPreserving (*this ))
6134
- return getVector ();
6135
-
6136
6112
return {};
6137
6113
}
6138
6114
@@ -6252,32 +6228,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6252
6228
}
6253
6229
};
6254
6230
6255
- // / Folds transpose(shape_cast) into a new shape_cast.
6256
- class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6257
- public:
6258
- using OpRewritePattern::OpRewritePattern;
6259
-
6260
- LogicalResult matchAndRewrite (TransposeOp transposeOp,
6261
- PatternRewriter &rewriter) const override {
6262
- auto shapeCastOp =
6263
- transposeOp.getVector ().getDefiningOp <vector::ShapeCastOp>();
6264
- if (!shapeCastOp)
6265
- return failure ();
6266
- if (!isOrderPreserving (transposeOp))
6267
- return failure ();
6268
-
6269
- VectorType resultType = transposeOp.getType ();
6270
-
6271
- // We don't need to check isValidShapeCast at this point, because it is
6272
- // guaranteed that merging the transpose into the the shape_cast is a valid
6273
- // shape_cast, because the transpose just inserts/removes ones.
6274
-
6275
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(transposeOp, resultType,
6276
- shapeCastOp.getSource ());
6277
- return success ();
6278
- }
6279
- };
6280
-
6281
6231
// / Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6282
6232
// / 'order preserving', where 'order preserving' means the flattened
6283
6233
// / inputs and outputs of the transpose have identical (numerical) values.
@@ -6373,12 +6323,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6373
6323
}
6374
6324
};
6375
6325
6326
+ // / Return true if `transpose` does not permute a pair of non-unit dims.
6327
+ // / By `order preserving` we mean that the flattened versions of the input and
6328
+ // / output vectors are (numerically) identical. In other words `transpose` is
6329
+ // / effectively a shape cast.
6330
+ static bool isOrderPreserving (TransposeOp transpose) {
6331
+ ArrayRef<int64_t > permutation = transpose.getPermutation ();
6332
+ VectorType sourceType = transpose.getSourceVectorType ();
6333
+ ArrayRef<int64_t > inShape = sourceType.getShape ();
6334
+ ArrayRef<bool > inDimIsScalable = sourceType.getScalableDims ();
6335
+ auto isNonScalableUnitDim = [&](int64_t dim) {
6336
+ return inShape[dim] == 1 && !inDimIsScalable[dim];
6337
+ };
6338
+ int64_t current = 0 ;
6339
+ for (auto p : permutation) {
6340
+ if (!isNonScalableUnitDim (p)) {
6341
+ if (p < current) {
6342
+ return false ;
6343
+ }
6344
+ current = p;
6345
+ }
6346
+ }
6347
+ return true ;
6348
+ }
6349
+
6350
+ // / For example,
6351
+ // / ```
6352
+ // / %0 = vector.transpose %arg0, [0, 2, 1] :
6353
+ // / vector<2x1x2xf32> to vector<2x2x1xf32>
6354
+ // / ```
6355
+ // / becomes
6356
+ // / ```
6357
+ // / %0 = vector.shape_cast %arg0 :
6358
+ // / vector<2x1x2xf32> to vector<2x2x1xf32>
6359
+ // / ```
6360
+ struct TransposeToShapeCast final
6361
+ : public OpRewritePattern<vector::TransposeOp> {
6362
+ using OpRewritePattern::OpRewritePattern;
6363
+ LogicalResult matchAndRewrite (vector::TransposeOp transpose,
6364
+ PatternRewriter &rewriter) const override {
6365
+
6366
+ // This folder does
6367
+ // shape_cast(transpose) -> shape_cast
6368
+ // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6369
+ // shape_cast -> shape_cast(transpose)
6370
+ // i.e. the complete opposite. When paired, these 2 patterns can cause
6371
+ // infinite cycles in pattern rewriting.
6372
+ // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6373
+ // vectors, so by disabling this folder for scalable vectors the
6374
+ // cycle is avoided.
6375
+ // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6376
+ // still needed. If it's not, then we can fold here.
6377
+ if (!isOrderPreserving (transpose) || transpose.getType ().isScalable ()) {
6378
+ return rewriter.notifyMatchFailure (
6379
+ transpose, " not order preserving, so not semantically a 'copy'" );
6380
+ }
6381
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
6382
+ transpose, transpose.getType (), transpose.getVector ());
6383
+ return success ();
6384
+ }
6385
+ };
6386
+
6376
6387
} // namespace
6377
6388
6378
6389
void vector::TransposeOp::getCanonicalizationPatterns (
6379
6390
RewritePatternSet &results, MLIRContext *context) {
6380
- results.add <FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder ,
6381
- FoldTransposeSplat, FoldTransposeBroadcast >(context);
6391
+ results.add <FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat ,
6392
+ FoldTransposeBroadcast, TransposeToShapeCast >(context);
6382
6393
}
6383
6394
6384
6395
// ===----------------------------------------------------------------------===//
0 commit comments