Skip to content

Commit 93c7918

Browse files
committed
[MLIR] Canonicalize/fold select %x, 1, 0 to extui
Two canonicalizations for select %x, 1, 0 If the return type is i1, return simply the condition %x, otherwise extui %x to the return type. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D116517
1 parent 834cf3b commit 93c7918

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,9 +840,43 @@ struct SelectToNot : public OpRewritePattern<SelectOp> {
840840
}
841841
};
842842

843+
// select %arg, %c1, %c0 => extui %arg
844+
struct SelectToExtUI : public OpRewritePattern<SelectOp> {
845+
using OpRewritePattern<SelectOp>::OpRewritePattern;
846+
847+
LogicalResult matchAndRewrite(SelectOp op,
848+
PatternRewriter &rewriter) const override {
849+
// Cannot extui i1 to i1, or i1 to f32
850+
if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
851+
return failure();
852+
853+
// select %x, c1, %c0 => extui %arg
854+
if (matchPattern(op.getTrueValue(), m_One()))
855+
if (matchPattern(op.getFalseValue(), m_Zero())) {
856+
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
857+
op.getCondition());
858+
return success();
859+
}
860+
861+
// select %x, c0, %c1 => extui (xor %arg, true)
862+
if (matchPattern(op.getTrueValue(), m_Zero()))
863+
if (matchPattern(op.getFalseValue(), m_One())) {
864+
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
865+
op, op.getType(),
866+
rewriter.create<arith::XOrIOp>(
867+
op.getLoc(), op.getCondition(),
868+
rewriter.create<arith::ConstantIntOp>(
869+
op.getLoc(), 1, op.getCondition().getType())));
870+
return success();
871+
}
872+
873+
return failure();
874+
}
875+
};
876+
843877
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
844878
MLIRContext *context) {
845-
results.insert<SelectToNot>(context);
879+
results.insert<SelectToNot, SelectToExtUI>(context);
846880
}
847881

848882
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
@@ -861,6 +895,12 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
861895
if (matchPattern(condition, m_Zero()))
862896
return falseVal;
863897

898+
// select %x, true, false => %x
899+
if (getType().isInteger(1))
900+
if (matchPattern(getTrueValue(), m_One()))
901+
if (matchPattern(getFalseValue(), m_Zero()))
902+
return condition;
903+
864904
if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
865905
auto pred = cmp.getPredicate();
866906
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,41 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
2929

3030
// -----
3131

32+
// CHECK-LABEL: @select_extui
33+
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
34+
// CHECK: return %[[res]]
35+
func @select_extui(%arg0: i1) -> i64 {
36+
%c0_i64 = arith.constant 0 : i64
37+
%c1_i64 = arith.constant 1 : i64
38+
%res = select %arg0, %c1_i64, %c0_i64 : i64
39+
return %res : i64
40+
}
41+
42+
// CHECK-LABEL: @select_extui2
43+
// CHECK-DAG: %true = arith.constant true
44+
// CHECK-DAG: %[[xor:.+]] = arith.xori %arg0, %true : i1
45+
// CHECK-DAG: %[[res:.+]] = arith.extui %[[xor]] : i1 to i64
46+
// CHECK: return %[[res]]
47+
func @select_extui2(%arg0: i1) -> i64 {
48+
%c0_i64 = arith.constant 0 : i64
49+
%c1_i64 = arith.constant 1 : i64
50+
%res = select %arg0, %c0_i64, %c1_i64 : i64
51+
return %res : i64
52+
}
53+
54+
// -----
55+
56+
// CHECK-LABEL: @select_extui_i1
57+
// CHECK-NEXT: return %arg0
58+
func @select_extui_i1(%arg0: i1) -> i1 {
59+
%c0_i1 = arith.constant false
60+
%c1_i1 = arith.constant true
61+
%res = select %arg0, %c1_i1, %c0_i1 : i1
62+
return %res : i1
63+
}
64+
65+
// -----
66+
3267
// CHECK-LABEL: @branchCondProp
3368
// CHECK: %[[trueval:.+]] = arith.constant true
3469
// CHECK: %[[falseval:.+]] = arith.constant false

0 commit comments

Comments
 (0)