-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) #147934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-vector Author: Kunwar Grover (Groverkss) ChangesFull diff: https://github.com/llvm/llvm-project/pull/147934.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..39c8191e8451a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// true: vector
+/// false: vector
+/// pred: i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// we cannot have arith depending on vector. Also, it would implicitly force
+/// users only using arith and vector dialect to use vector dialect. Instead,
+/// this canonicalization only runs if vector::BroadcastOp was a registered
+/// operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+ PatternRewriter &rewriter) const override {
+ auto vecType = dyn_cast<VectorType>(selectOp.getType());
+ if (!vecType || !vecType.getElementType().isInteger(1))
+ return failure();
+
+ // Vector conditionals do not need broadcast and are already handled by
+ // the arith.select folder.
+ Value pred = selectOp.getCondition();
+ if (isa<VectorType>(pred.getType()))
+ return failure();
+
+ std::optional<int64_t> trueInt =
+ getConstantIntValue(selectOp.getTrueValue());
+ std::optional<int64_t> falseInt =
+ getConstantIntValue(selectOp.getFalseValue());
+ if (!trueInt || !falseInt)
+ return failure();
+
+ // Redundant selects are already handled by arith.select canonicalizations.
+ if (trueInt.value() == falseInt.value()) {
+ return failure();
+ }
+
+ // The only remaining possibilities are:
+ //
+ // select(pred, true, false)
+ // select(pred, false, true)
+
+ // select(pred, false, true) -> select(not(pred), true, false)
+ if (trueInt.value() == 0) {
+ Value one = rewriter.create<arith::ConstantIntOp>(
+ selectOp.getLoc(), /*value=*/1, /*width=*/1);
+ pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+ }
+
+ /// select(pred, true, false) -> broadcast(pred)
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ selectOp, vecType.clone(rewriter.getI1Type()), pred);
+ return success();
+
+ return failure();
+ }
+};
+
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..5924e7ea856c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %true, %false : vector<4x4xi1>
+ // The select -> broadcast pattern only loads if vector dialect was loaded.
+ // Force loading vector dialect by adding a vector operation.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1
+// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %false, %true : vector<4x4xi1>
+ // The select -> broadcast pattern only loads if vector dialect was loaded.
+ // Force loading vector dialect by adding a vector operation.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
@llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesFull diff: https://github.com/llvm/llvm-project/pull/147934.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..39c8191e8451a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// true: vector
+/// false: vector
+/// pred: i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// we cannot have arith depending on vector. Also, it would implicitly force
+/// users only using arith and vector dialect to use vector dialect. Instead,
+/// this canonicalization only runs if vector::BroadcastOp was a registered
+/// operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+ PatternRewriter &rewriter) const override {
+ auto vecType = dyn_cast<VectorType>(selectOp.getType());
+ if (!vecType || !vecType.getElementType().isInteger(1))
+ return failure();
+
+ // Vector conditionals do not need broadcast and are already handled by
+ // the arith.select folder.
+ Value pred = selectOp.getCondition();
+ if (isa<VectorType>(pred.getType()))
+ return failure();
+
+ std::optional<int64_t> trueInt =
+ getConstantIntValue(selectOp.getTrueValue());
+ std::optional<int64_t> falseInt =
+ getConstantIntValue(selectOp.getFalseValue());
+ if (!trueInt || !falseInt)
+ return failure();
+
+ // Redundant selects are already handled by arith.select canonicalizations.
+ if (trueInt.value() == falseInt.value()) {
+ return failure();
+ }
+
+ // The only remaining possibilities are:
+ //
+ // select(pred, true, false)
+ // select(pred, false, true)
+
+ // select(pred, false, true) -> select(not(pred), true, false)
+ if (trueInt.value() == 0) {
+ Value one = rewriter.create<arith::ConstantIntOp>(
+ selectOp.getLoc(), /*value=*/1, /*width=*/1);
+ pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+ }
+
+ /// select(pred, true, false) -> broadcast(pred)
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ selectOp, vecType.clone(rewriter.getI1Type()), pred);
+ return success();
+
+ return failure();
+ }
+};
+
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..5924e7ea856c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %true, %false : vector<4x4xi1>
+ // The select -> broadcast pattern only loads if vector dialect was loaded.
+ // Force loading vector dialect by adding a vector operation.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1
+// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %false, %true : vector<4x4xi1>
+ // The select -> broadcast pattern only loads if vector dialect was loaded.
+ // Force loading vector dialect by adding a vector operation.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
Ouch, we're suffering that the vector type is disconnected from the vector dialect. Here the main concern I have is that running a pass pipeline |
I agree... it stems from the fact that we don't have a way to splat a value to a vector in the arith dialect, but we handle things like:
we really have no way of converting a scalar to vector without going through vector dialect. We miss a bunch of canonicalizations that llvm does in InstCombine for selects because of this: llvm would convert that select into:
and the existing select folders would fold the select as needed. I'm not sure what the best way of progressing forward is here. Personally, based on the current design of arith dialect and it's inability to splat to a vector, this seems to be most fitting way for now. |
|
||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dead return
return failure(); |
} // namespace | ||
|
||
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by | ||
// calling `populateCastAwayVectorLeadingOneDimPatterns` | ||
results.add<BroadcastFolder>(context); | ||
results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have some precedent for canon patterns hooked up to the op they produce?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is quite counter-intuitive :( But not that uncommon:
llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Lines 5357 to 5360 in 5f1141d
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
MLIRContext *context) { | |
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context); | |
} |
SwapExtractSliceOfTransferWrite
is a canonicalization for xfer_write that matches on tensor.insert_slice
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Adds a canonicalization pattern for vector.broadcast to fold:
Hm, to me this is "rewriting" arith.select
as vector.broadcast
(as in, intuitively this isn't really folding). Why is vector.broadcast
more desirable than arith.select
?
Lets discuss this first - you can skip my other comments in the meantime.
/// true: vector | ||
/// false: vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// true: vector | |
/// false: vector | |
/// true: vector of i1 | |
/// false: vector of i1 |
/// | ||
/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but | ||
/// we cannot have arith depending on vector. Also, it would implicitly force | ||
/// users only using arith and vector dialect to use vector dialect. Instead, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo?
/// users only using arith and vector dialect to use vector dialect. Instead, | |
/// users only using arith to use vector dialect. Instead, |
} // namespace | ||
|
||
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by | ||
// calling `populateCastAwayVectorLeadingOneDimPatterns` | ||
results.add<BroadcastFolder>(context); | ||
results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is quite counter-intuitive :( But not that uncommon:
llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Lines 5357 to 5360 in 5f1141d
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
MLIRContext *context) { | |
results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context); | |
} |
SwapExtractSliceOfTransferWrite
is a canonicalization for xfer_write that matches on tensor.insert_slice
.
/// select(pred, true, false) -> broadcast(pred) | ||
/// select(pred, false, true) -> broadcast(not(pred)) | ||
/// | ||
/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this paragraph makes more sense near where you register this canonicalization.
arith.select evaluates the condition and selects between the true and false vector. broadcasts directly propagates the condition forward. It also unlocks multiple folding opportunities. For example:
here, in the current ir, it is quite hard to actually reason if the vector.maskedload mask is either true or false (and not partially true/false), which prevents folding opportunities. On the other hand:
will fold into:
and now it is easy to match that the mask is just a broadcasted conditional, so it must be either always full or always empty. vector.broadcast has very clear semantics that the entire vector has the same value, on the other hand, for select, we actually have to traverse the true/false values to check if the entire vector has the same value. I will also note that LLVM InstCombine does a similar canonicalization as noted above: https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp#L2351 The LLVM canonicalization converts the i1 pred into a broadcasted vector i1 pred, which further just folds away directly, completing the same fold we are doing here. Personally, for me, this is a very clear canonicalization. vector.broadcast has much more restricted semantics vs arith.select and composes better with other vector operations (unless we plan to write folders for every vector operation that interacts with broadcast, to also work with (select pred, true , false) which is very counter intuitive). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nits after the comment from @banach-space showing we have more canon patterns that hook into target ops
Adds a canonicalization pattern for vector.broadcast to fold:
select(pred, true, false) -> broadcast(pred)
select(pred, true, false) -> broadcast(not(pred))
This pattern should have been ideally registered as a canonicalization for arith.select, but we cannot make arith dialect canonicalizations depend on vector dialect, so it's added as a vector.broadcast canonicalization instead.