Skip to content

Commit 0909b5b

Browse files
committed
use shape_cast as canonical type for extract, broadcast, and transpose
1 parent e8a2ce1 commit 0909b5b

File tree

8 files changed

+291
-195
lines changed

8 files changed

+291
-195
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 128 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,11 +2344,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23442344
return success();
23452345
}
23462346

2347+
/// For example,
2348+
/// ```
2349+
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2350+
/// ```
2351+
/// becomes
2352+
/// ```
2353+
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2354+
/// ```
2355+
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2356+
using OpRewritePattern::OpRewritePattern;
2357+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2358+
PatternRewriter &rewriter) const override {
2359+
VectorType sourceType = extractOp.getSourceVectorType();
2360+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2361+
if (!outType)
2362+
return failure();
2363+
2364+
// Negative values in `position` indicates poison, cannot convert to
2365+
// shape_cast
2366+
if (llvm::any_of(extractOp.getMixedPosition(),
2367+
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2368+
return failure();
2369+
2370+
if (sourceType.getNumElements() != outType.getNumElements())
2371+
return failure();
2372+
2373+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2374+
extractOp.getVector());
2375+
return success();
2376+
}
2377+
};
2378+
23472379
} // namespace
23482380

23492381
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23502382
MLIRContext *context) {
2351-
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2383+
results
2384+
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2385+
context);
23522386
results.add(foldExtractFromShapeCastToShapeCast);
23532387
results.add(foldExtractFromFromElements);
23542388
}
@@ -2651,13 +2685,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
26512685
return success();
26522686
}
26532687
};
2688+
2689+
/// For example,
2690+
/// ```
2691+
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2692+
/// ```
2693+
/// becomes
2694+
/// ```
2695+
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2696+
/// ```
2697+
struct BroadcastToShapeCast final
2698+
: public OpRewritePattern<vector::BroadcastOp> {
2699+
using OpRewritePattern::OpRewritePattern;
2700+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
2701+
PatternRewriter &rewriter) const override {
2702+
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
2703+
if (!sourceType) {
2704+
return rewriter.notifyMatchFailure(
2705+
broadcast, "source is a scalar, shape_cast doesn't support scalar");
2706+
}
2707+
2708+
VectorType outType = broadcast.getType();
2709+
if (sourceType.getNumElements() != outType.getNumElements())
2710+
return failure();
2711+
2712+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
2713+
broadcast.getSource());
2714+
return success();
2715+
}
2716+
};
26542717
} // namespace
26552718

26562719
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
26572720
MLIRContext *context) {
2658-
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2659-
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2660-
results.add<BroadcastFolder>(context);
2721+
results.add<BroadcastFolder, BroadcastToShapeCast>(context);
26612722
}
26622723

26632724
//===----------------------------------------------------------------------===//
@@ -5573,30 +5634,6 @@ LogicalResult ShapeCastOp::verify() {
55735634
return success();
55745635
}
55755636

5576-
/// Return true if `transpose` does not permute a pair of non-unit dims.
5577-
/// By `order preserving` we mean that the flattened versions of the input and
5578-
/// output vectors are (numerically) identical. In other words `transpose` is
5579-
/// effectively a shape cast.
5580-
static bool isOrderPreserving(TransposeOp transpose) {
5581-
ArrayRef<int64_t> permutation = transpose.getPermutation();
5582-
VectorType sourceType = transpose.getSourceVectorType();
5583-
ArrayRef<int64_t> inShape = sourceType.getShape();
5584-
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5585-
auto isNonScalableUnitDim = [&](int64_t dim) {
5586-
return inShape[dim] == 1 && !inDimIsScalable[dim];
5587-
};
5588-
int64_t current = 0;
5589-
for (auto p : permutation) {
5590-
if (!isNonScalableUnitDim(p)) {
5591-
if (p < current) {
5592-
return false;
5593-
}
5594-
current = p;
5595-
}
5596-
}
5597-
return true;
5598-
}
5599-
56005637
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56015638

56025639
VectorType resultType = getType();
@@ -5611,33 +5648,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56115648
return getResult();
56125649
}
56135650

5614-
// shape_cast(transpose(x)) -> shape_cast(x)
5615-
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5616-
// This folder does
5617-
// shape_cast(transpose) -> shape_cast
5618-
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5619-
// shape_cast -> shape_cast(transpose)
5620-
// i.e. the complete opposite. When paired, these 2 patterns can cause
5621-
// infinite cycles in pattern rewriting.
5622-
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5623-
// vectors, so by disabling this folder for scalable vectors the
5624-
// cycle is avoided.
5625-
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5626-
// still needed. If it's not, then we can fold here.
5627-
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5628-
setOperand(transpose.getVector());
5629-
return getResult();
5630-
}
5631-
return {};
5632-
}
5633-
5634-
// Y = shape_cast(broadcast(X))
5635-
// -> X, if X and Y have same type
5636-
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5637-
if (bcastOp.getSourceType() == resultType)
5638-
return bcastOp.getSource();
5639-
}
5640-
56415651
// shape_cast(constant) -> constant
56425652
if (auto splatAttr =
56435653
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -5993,21 +6003,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
59936003
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
59946004
return ub::PoisonAttr::get(getContext());
59956005

5996-
// Eliminate identity transposes, and more generally any transposes that
5997-
// preserves the shape without permuting elements.
5998-
//
5999-
// Examples of what to fold:
6000-
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6001-
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6002-
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6003-
//
6004-
// Example of what NOT to fold:
6005-
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6006-
//
6007-
if (getSourceVectorType() == getResultVectorType() &&
6008-
isOrderPreserving(*this))
6009-
return getVector();
6010-
60116006
return {};
60126007
}
60136008

@@ -6127,32 +6122,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61276122
}
61286123
};
61296124

6130-
/// Folds transpose(shape_cast) into a new shape_cast.
6131-
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6132-
public:
6133-
using OpRewritePattern::OpRewritePattern;
6134-
6135-
LogicalResult matchAndRewrite(TransposeOp transposeOp,
6136-
PatternRewriter &rewriter) const override {
6137-
auto shapeCastOp =
6138-
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6139-
if (!shapeCastOp)
6140-
return failure();
6141-
if (!isOrderPreserving(transposeOp))
6142-
return failure();
6143-
6144-
VectorType resultType = transposeOp.getType();
6145-
6146-
// We don't need to check isValidShapeCast at this point, because it is
6147-
// guaranteed that merging the transpose into the shape_cast is a valid
6148-
// shape_cast, because the transpose just inserts/removes ones.
6149-
6150-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6151-
shapeCastOp.getSource());
6152-
return success();
6153-
}
6154-
};
6155-
61566125
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
61576126
/// 'order preserving', where 'order preserving' means the flattened
61586127
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6248,12 +6217,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62486217
}
62496218
};
62506219

6220+
/// Return true if `transpose` does not permute a pair of non-unit dims.
6221+
/// By `order preserving` we mean that the flattened versions of the input and
6222+
/// output vectors are (numerically) identical. In other words `transpose` is
6223+
/// effectively a shape cast.
6224+
static bool isOrderPreserving(TransposeOp transpose) {
6225+
ArrayRef<int64_t> permutation = transpose.getPermutation();
6226+
VectorType sourceType = transpose.getSourceVectorType();
6227+
ArrayRef<int64_t> inShape = sourceType.getShape();
6228+
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6229+
auto isNonScalableUnitDim = [&](int64_t dim) {
6230+
return inShape[dim] == 1 && !inDimIsScalable[dim];
6231+
};
6232+
int64_t current = 0;
6233+
for (auto p : permutation) {
6234+
if (!isNonScalableUnitDim(p)) {
6235+
if (p < current) {
6236+
return false;
6237+
}
6238+
current = p;
6239+
}
6240+
}
6241+
return true;
6242+
}
6243+
6244+
/// For example,
6245+
/// ```
6246+
/// %0 = vector.transpose %arg0, [0, 2, 1] :
6247+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6248+
/// ```
6249+
/// becomes
6250+
/// ```
6251+
/// %0 = vector.shape_cast %arg0 :
6252+
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6253+
/// ```
6254+
struct TransposeToShapeCast final
6255+
: public OpRewritePattern<vector::TransposeOp> {
6256+
using OpRewritePattern::OpRewritePattern;
6257+
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6258+
PatternRewriter &rewriter) const override {
6259+
6260+
// This folder does
6261+
// shape_cast(transpose) -> shape_cast
6262+
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6263+
// shape_cast -> shape_cast(transpose)
6264+
// i.e. the complete opposite. When paired, these 2 patterns can cause
6265+
// infinite cycles in pattern rewriting.
6266+
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6267+
// vectors, so by disabling this folder for scalable vectors the
6268+
// cycle is avoided.
6269+
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6270+
// still needed. If it's not, then we can fold here.
6271+
if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
6272+
return rewriter.notifyMatchFailure(
6273+
transpose, "not order preserving, so not semantically a 'copy'");
6274+
}
6275+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6276+
transpose, transpose.getType(), transpose.getVector());
6277+
return success();
6278+
}
6279+
};
6280+
62516281
} // namespace
62526282

62536283
void vector::TransposeOp::getCanonicalizationPatterns(
62546284
RewritePatternSet &results, MLIRContext *context) {
6255-
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6256-
FoldTransposeSplat, FoldTransposeBroadcast>(context);
6285+
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6286+
FoldTransposeBroadcast, TransposeToShapeCast>(context);
62576287
}
62586288

62596289
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -754,11 +754,11 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
754754
// -----
755755

756756
// CHECK-LABEL: fold_extract_broadcast_negative
757-
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
758-
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
757+
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
758+
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
759759
func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
760-
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
761-
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
760+
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
761+
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
762762
return %r : vector<4xf32>
763763
}
764764

@@ -797,8 +797,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
797797
// rank(extract_output) < rank(broadcast_input)
798798
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
799799
%idx0 : index, %idx1 : index) -> vector<4xf32> {
800-
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
801-
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
800+
%b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
801+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
802802
return %r : vector<4xf32>
803803
}
804804

@@ -1840,12 +1840,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
18401840

18411841
// -----
18421842

1843-
// CHECK-LABEL: func @insert_extract_to_broadcast
1843+
// CHECK-LABEL: func @insert_extract_to_shape_cast
18441844
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
1845-
// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
1846-
// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
1845+
// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
1846+
// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
18471847
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
1848-
func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
1848+
func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
18491849
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
18501850
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
18511851
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2197,7 +2197,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
21972197

21982198
// CHECK-LABEL: func @shuffle_canonicalize_0d
21992199
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
2200-
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
2200+
// CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
22012201
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
22022202
return %shuffle : vector<1xi32>
22032203
}
@@ -2684,9 +2684,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
26842684
// CHECK-LABEL: func.func @extract_from_broadcast
26852685
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
26862686
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
2687-
2688-
// CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
2689-
// CHECK-NEXT: return %0 : vector<1xf32>
2687+
// CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
2688+
// CHECK-NEXT: return %[[RES]] : vector<1xf32>
26902689
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
26912690
return %1: vector<1xf32>
26922691
}

mlir/test/Dialect/Vector/canonicalize/playtime.mlir

Whitespace-only changes.

0 commit comments

Comments
 (0)