@@ -2344,11 +2344,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// For example,
+/// ```
+///  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+///  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicate poison, which cannot be
+    // converted to shape_cast.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
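+    // With all-zero positions, matching element counts imply that the dims
+    // dropped by the extract are all unit dims, so every element is kept in
+    // its original order.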
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2651,13 +2685,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// For example,
+/// ```
+///  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+///  %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
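+    // Equal element counts mean the broadcast duplicates nothing; it only
+    // prepends unit dims, so the result is a pure reshape of the source.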
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5573,30 +5634,6 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-/// Return true if `transpose` does not permute a pair of non-unit dims.
-/// By `order preserving` we mean that the flattened versions of the input and
-/// output vectors are (numerically) identical. In other words `transpose` is
-/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
-  ArrayRef<int64_t> permutation = transpose.getPermutation();
-  VectorType sourceType = transpose.getSourceVectorType();
-  ArrayRef<int64_t> inShape = sourceType.getShape();
-  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
-  auto isNonScalableUnitDim = [&](int64_t dim) {
-    return inShape[dim] == 1 && !inDimIsScalable[dim];
-  };
-  int64_t current = 0;
-  for (auto p : permutation) {
-    if (!isNonScalableUnitDim(p)) {
-      if (p < current) {
-        return false;
-      }
-      current = p;
-    }
-  }
-  return true;
-}
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5611,33 +5648,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  // shape_cast(transpose(x)) -> shape_cast(x)
-  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
-      setOperand(transpose.getVector());
-      return getResult();
-    }
-    return {};
-  }
-
-  // Y = shape_cast(broadcast(X))
-  //      -> X, if X and Y have same type
-  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
-    if (bcastOp.getSourceType() == resultType)
-      return bcastOp.getSource();
-  }
-
   // shape_cast(constant) -> constant
   if (auto splatAttr =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5993,21 +6003,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
     return ub::PoisonAttr::get(getContext());
 
-  // Eliminate identity transposes, and more generally any transposes that
-  // preserve the shape without permuting elements.
-  //
-  // Examples of what to fold:
-  //  %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
-  //  %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
-  //  %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
-  //
-  // Example of what NOT to fold:
-  //  %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
-  //
-  if (getSourceVectorType() == getResultVectorType() &&
-      isOrderPreserving(*this))
-    return getVector();
-
   return {};
 }
 
@@ -6127,32 +6122,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6248,12 +6217,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+static bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  VectorType sourceType = transpose.getSourceVectorType();
+  ArrayRef<int64_t> inShape = sourceType.getShape();
+  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+  auto isNonScalableUnitDim = [&](int64_t dim) {
+    return inShape[dim] == 1 && !inDimIsScalable[dim];
+  };
+  int64_t current = 0;
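+  // Non-unit (and scalable) dims must keep their original relative order in
+  // the permutation; non-scalable unit dims may move freely, since moving
+  // them does not change the flattened element order.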
+  for (auto p : permutation) {
+    if (!isNonScalableUnitDim(p)) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+/// For example,
+/// ```
+///  %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+///  %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    // This pattern does
+    //    transpose -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    //    shape_cast -> shape_cast(transpose)
+    // i.e. the complete opposite. When paired, these 2 patterns can cause
+    // infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this pattern for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
+    // still needed. If it's not, this pattern can handle scalable vectors too.
+    if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
+              FoldTransposeBroadcast, TransposeToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//