From bd6d20d16be7d925cf08c3f292a383f0389d4d8f Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Thu, 10 Jul 2025 11:24:57 +0100 Subject: [PATCH] [Vector] Add folder for select(pred, true, false) -> broadcast(pred) --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 63 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 32 +++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1fb8c7a928e06..39c8191e8451a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; + +/// true: vector +/// false: vector +/// pred: i1 +/// +/// select(pred, true, false) -> broadcast(pred) +/// select(pred, false, true) -> broadcast(not(pred)) +/// +/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but +/// we cannot have arith depending on vector. Also, it would implicitly force +/// users only using arith and vector dialect to use vector dialect. Instead, +/// this canonicalization only runs if vector::BroadcastOp was a registered +/// operation. +struct FoldI1SelectToBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SelectOp selectOp, + PatternRewriter &rewriter) const override { + auto vecType = dyn_cast(selectOp.getType()); + if (!vecType || !vecType.getElementType().isInteger(1)) + return failure(); + + // Vector conditionals do not need broadcast and are already handled by + // the arith.select folder. + Value pred = selectOp.getCondition(); + if (isa(pred.getType())) + return failure(); + + std::optional trueInt = + getConstantIntValue(selectOp.getTrueValue()); + std::optional falseInt = + getConstantIntValue(selectOp.getFalseValue()); + if (!trueInt || !falseInt) + return failure(); + + // Redundant selects are already handled by arith.select canonicalizations. + if (trueInt.value() == falseInt.value()) { + return failure(); + } + + // The only remaining possibilities are: + // + // select(pred, true, false) + // select(pred, false, true) + + // select(pred, false, true) -> select(not(pred), true, false) + if (trueInt.value() == 0) { + Value one = rewriter.create( + selectOp.getLoc(), /*value=*/1, /*width=*/1); + pred = rewriter.create(selectOp.getLoc(), pred, one); + } + + /// select(pred, true, false) -> broadcast(pred) + rewriter.replaceOpWithNewOp( + selectOp, vecType.clone(rewriter.getI1Type()), pred); + return success(); + + return failure(); + } +}; + } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // BroadcastToShapeCast is not a default canonicalization, it is opt-in by // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0282e9cac5e02..5924e7ea856c4 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) // ----- +// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast +// CHECK-SAME: (%[[PRED:.+]]: i1) +// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1> +func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> { + %true = arith.constant dense : vector<4x4xi1> + %false = arith.constant dense : vector<4x4xi1> + %selected = arith.select %pred, %true, %false : vector<4x4xi1> + // The select -> broadcast pattern only loads if vector dialect was loaded. + // Force loading vector dialect by adding a vector operation. + %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1> + return %vec : vector<4xi1> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast +// CHECK-SAME: (%[[PRED:.+]]: i1) +// CHECK: %[[TRUE:.+]] = arith.constant true +// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1 +// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1> +func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> { + %true = arith.constant dense : vector<4x4xi1> + %false = arith.constant dense : vector<4x4xi1> + %selected = arith.select %pred, %false, %true : vector<4x4xi1> + // The select -> broadcast pattern only loads if vector dialect was loaded. + // Force loading vector dialect by adding a vector operation. + %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1> + return %vec : vector<4xi1> +} + +// ----- + // CHECK-LABEL: fold_vector_transfer_masks func.func @fold_vector_transfer_masks(%A: memref) -> (vector<4x8xf32>, vector<4x[4]xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index