[Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) #147934

Open
Groverkss wants to merge 1 commit into main

Conversation

Groverkss
Member

@Groverkss Groverkss commented Jul 10, 2025

Adds a canonicalization pattern for vector.broadcast to fold:

select(pred, true, false) -> broadcast(pred)
select(pred, false, true) -> broadcast(not(pred))

Ideally, this pattern would have been registered as a canonicalization for arith.select, but arith dialect canonicalizations cannot depend on the vector dialect, so it is added as a vector.broadcast canonicalization instead.
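
For illustration, a minimal before/after sketch of the rewrite (mirroring the new tests; it assumes the vector dialect is loaded so the pattern is registered, and the function name is made up):

// Before: an i1 predicate selects between all-true and all-false vectors.
func.func @select_to_broadcast(%pred: i1) -> vector<4xi1> {
  %true = arith.constant dense<true> : vector<4xi1>
  %false = arith.constant dense<false> : vector<4xi1>
  %sel = arith.select %pred, %true, %false : vector<4xi1>
  return %sel : vector<4xi1>
}

// After canonicalization: the select becomes a broadcast of the predicate.
func.func @select_to_broadcast(%pred: i1) -> vector<4xi1> {
  %bcast = vector.broadcast %pred : i1 to vector<4xi1>
  return %bcast : vector<4xi1>
}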

@llvmbot
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/147934.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+62-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+32)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..39c8191e8451a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// true: vector
+/// false: vector
+/// pred: i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// we cannot have arith depending on vector. Also, it would implicitly force
+/// users only using arith and vector dialect to use vector dialect. Instead,
+/// this canonicalization only runs if vector::BroadcastOp was a registered
+/// operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Vector conditionals do not need broadcast and are already handled by
+    // the arith.select folder.
+    Value pred = selectOp.getCondition();
+    if (isa<VectorType>(pred.getType()))
+      return failure();
+
+    std::optional<int64_t> trueInt =
+        getConstantIntValue(selectOp.getTrueValue());
+    std::optional<int64_t> falseInt =
+        getConstantIntValue(selectOp.getFalseValue());
+    if (!trueInt || !falseInt)
+      return failure();
+
+    // Redundant selects are already handled by arith.select canonicalizations.
+    if (trueInt.value() == falseInt.value()) {
+      return failure();
+    }
+
+    // The only remaining possibilities are:
+    //
+    // select(pred, true, false)
+    // select(pred, false, true)
+
+    // select(pred, false, true) -> select(not(pred), true, false)
+    if (trueInt.value() == 0) {
+      Value one = rewriter.create<arith::ConstantIntOp>(
+          selectOp.getLoc(), /*value=*/1, /*width=*/1);
+      pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+    }
+
+    /// select(pred, true, false) -> broadcast(pred)
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        selectOp, vecType.clone(rewriter.getI1Type()), pred);
+    return success();
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..5924e7ea856c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> {
+  %true = arith.constant dense<true> : vector<4x4xi1>
+  %false = arith.constant dense<false> : vector<4x4xi1>
+  %selected = arith.select %pred, %true, %false : vector<4x4xi1>
+  // The select -> broadcast pattern only runs if the vector dialect was loaded.
+  // Force loading vector dialect by adding a vector operation.
+  %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+  return %vec : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1
+// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> {
+  %true = arith.constant dense<true> : vector<4x4xi1>
+  %false = arith.constant dense<false> : vector<4x4xi1>
+  %selected = arith.select %pred, %false, %true : vector<4x4xi1>
+  // The select -> broadcast pattern only runs if the vector dialect was loaded.
+  // Force loading vector dialect by adding a vector operation.
+  %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+  return %vec : vector<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

@joker-eph
Collaborator

Ouch, we're suffering from the fact that the vector type is disconnected from the vector dialect.
The same issue exists with the tensor type: in general, I think we always run into this kind of issue when we decouple types from the operations on them.

Here my main concern is that running a pass pipeline passA, passB would not be the same as running a pipeline with passA followed by a pipeline with passB.
(This may already exist in other cases, but for such canonicalizations involving builtin types, it's very visible.)
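
For example, a hypothetical sketch of the concern (function name made up): this module uses only arith ops, so canonicalizing it on its own would not have the vector dialect loaded and the select stays; canonicalizing it after a pass that happened to load the vector dialect would rewrite it.

func.func @only_arith(%pred: i1) -> vector<4xi1> {
  %true = arith.constant dense<true> : vector<4xi1>
  %false = arith.constant dense<false> : vector<4xi1>
  // Left untouched when the vector dialect is not loaded; rewritten to
  // vector.broadcast %pred : i1 to vector<4xi1> when it is.
  %sel = arith.select %pred, %true, %false : vector<4xi1>
  return %sel : vector<4xi1>
}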

@Groverkss
Member Author

Groverkss commented Jul 10, 2025

Ouch, we're suffering from the fact that the vector type is disconnected from the vector dialect. The same issue exists with the tensor type: in general, I think we always run into this kind of issue when we decouple types from the operations on them.

Here my main concern is that running a pass pipeline passA, passB would not be the same as running a pipeline with passA followed by a pipeline with passB. (This may already exist in other cases, but for such canonicalizations involving builtin types, it's very visible.)

I agree... it stems from the fact that we don't have a way to splat a value to a vector in the arith dialect. For something like:

%cond: i1
%a: vector<4xf16>
%b: vector<4xf16>
select %cond, %a, %b

we really have no way of converting the scalar condition to a vector without going through the vector dialect. We miss a bunch of canonicalizations that LLVM does in InstCombine for selects because of this:

https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp#L2351

LLVM would convert that select into:

%cond: i1
%a: vector<4xf16>
%b: vector<4xf16>
%splat = splat %cond -> vector<4xi1>
select %splat, %a, %b

and the existing select folders would fold the select as needed.

I'm not sure what the best way forward is here. Personally, given the current design of the arith dialect and its inability to splat to a vector, this seems to be the most fitting way for now.

@joker-eph joker-eph changed the title [Vector] Add folder for select(pred, true, false) -> broadcast(pred) [Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) Jul 10, 2025
@Groverkss Groverkss requested a review from kuhar July 10, 2025 14:32
Comment on lines +2972 to +2973

return failure();
Member

dead return

Suggested change:
- return failure();

} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
Member

Do we have some precedent for canon patterns hooked up to the op they produce?

Contributor

This is quite counter-intuitive :( But not that uncommon:

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}

SwapExtractSliceOfTransferWrite is a canonicalization for xfer_write that matches on tensor.insert_slice.

Contributor

@banach-space banach-space left a comment

Thanks!

Adds a canonicalization pattern for vector.broadcast to fold:

Hm, to me this is "rewriting" arith.select as vector.broadcast (as in, intuitively this isn't really folding). Why is vector.broadcast more desirable than arith.select?

Let's discuss this first - you can skip my other comments in the meantime.

Comment on lines +2917 to +2918
/// true: vector
/// false: vector
Contributor

Suggested change:
- /// true: vector
- /// false: vector
+ /// true: vector of i1
+ /// false: vector of i1

///
/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
/// we cannot have arith depending on vector. Also, it would implicitly force
/// users only using arith and vector dialect to use vector dialect. Instead,
Contributor

Typo?

Suggested change:
- /// users only using arith and vector dialect to use vector dialect. Instead,
+ /// users only using arith to use vector dialect. Instead,

/// select(pred, true, false) -> broadcast(pred)
/// select(pred, false, true) -> broadcast(not(pred))
///
/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
Contributor

I think that this paragraph makes more sense near where you register this canonicalization.

@Groverkss
Member Author

Thanks!

Adds a canonicalization pattern for vector.broadcast to fold:

Hm, to me this is "rewriting" arith.select as vector.broadcast (as in, intuitively this isn't really folding). Why is vector.broadcast more desirable than arith.select?

Let's discuss this first - you can skip my other comments in the meantime.

arith.select evaluates the condition and selects between the true and false vectors; broadcast directly propagates the condition forward. It also unlocks multiple folding opportunities. For example:

%select = arith.select %pred, %true, %false : vector<4xi1>
%broadcasted = vector.broadcast %select : vector<2x4xi1>
%load = vector.maskedload .... /*mask=*/%broadcasted

Here, in the current IR, it is quite hard to reason about whether the vector.maskedload mask is all-true or all-false (and not partially true/false), which prevents folding opportunities. On the other hand:

%select = vector.broadcast %pred : vector<4xi1>
%broadcasted = vector.broadcast %select : vector<2x4xi1>
%load = vector.maskedload .... /*mask=*/%broadcasted

will fold into:

%broadcasted = vector.broadcast %pred : vector<2x4xi1>
%load = vector.maskedload .... /*mask=*/%broadcasted

and now it is easy to match that the mask is just a broadcasted conditional, so it must be either always full or always empty.
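
Spelled out with full broadcast syntax, the intermediate broadcast-of-broadcast fold this relies on looks roughly like (shapes as above, names made up):

%b0 = vector.broadcast %pred : i1 to vector<4xi1>
%b1 = vector.broadcast %b0 : vector<4xi1> to vector<2x4xi1>
// folds to:
%b1 = vector.broadcast %pred : i1 to vector<2x4xi1>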

vector.broadcast has very clear semantics: the entire vector has the same value. For arith.select, on the other hand, we actually have to traverse the true/false values to check whether the entire vector has the same value.

I will also note that LLVM InstCombine does a similar canonicalization as noted above: https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp#L2351

The LLVM canonicalization converts the i1 pred into a broadcast vector i1 pred, which then folds away directly, completing the same fold we are doing here.

Personally, this is a very clear canonicalization to me. vector.broadcast has much more restricted semantics than arith.select and composes better with other vector operations (unless we plan to make the folders of every vector operation that interacts with broadcast also handle select(pred, true, false), which would be very counter-intuitive).

Member

@kuhar kuhar left a comment

LGTM % nits after the comment from @banach-space showing we have more canon patterns that hook into target ops
