Skip to content

Commit 834cf3b

Browse files
committed
[MLIR][Arith] Canonicalize and/or with ext
Replace and(ext(a),ext(b)) with ext(and(a,b)). This both reduces one instruction, and results in the computation (and/or) being done on a smaller type. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D116519
1 parent 78389de commit 834cf3b

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> {
437437
```
438438
}];
439439
let hasFolder = 1;
440+
let hasCanonicalizer = 1;
440441
}
441442

442443
//===----------------------------------------------------------------------===//
@@ -465,6 +466,7 @@ def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> {
465466
```
466467
}];
467468
let hasFolder = 1;
469+
let hasCanonicalizer = 1;
468470
}
469471

470472
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,32 @@ def BitcastOfBitcast :
136136
def ExtSIOfExtUI :
137137
Pat<(Arith_ExtSIOp (Arith_ExtUIOp $x)), (Arith_ExtUIOp $x)>;
138138

139+
//===----------------------------------------------------------------------===//
140+
// AndIOp
141+
//===----------------------------------------------------------------------===//
142+
143+
// and extui(x), extui(y) -> extui(and(x,y))
144+
def AndOfExtUI :
145+
Pat<(Arith_AndIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_AndIOp $x, $y)),
146+
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
147+
148+
// and extsi(x), extsi(y) -> extsi(and(x,y))
149+
def AndOfExtSI :
150+
Pat<(Arith_AndIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_AndIOp $x, $y)),
151+
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
152+
153+
//===----------------------------------------------------------------------===//
154+
// OrIOp
155+
//===----------------------------------------------------------------------===//
156+
157+
// or extui(x), extui(y) -> extui(or(x,y))
158+
def OrOfExtUI :
159+
Pat<(Arith_OrIOp (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), (Arith_ExtUIOp (Arith_OrIOp $x, $y)),
160+
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
161+
162+
// or extsi(x), extsi(y) -> extsi(or(x,y))
163+
def OrOfExtSI :
164+
Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
165+
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
166+
139167
#endif // ARITHMETIC_PATTERNS

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,24 @@ bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
901901
return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
902902
}
903903

904+
//===----------------------------------------------------------------------===//
905+
// AndIOp
906+
//===----------------------------------------------------------------------===//
907+
908+
void arith::AndIOp::getCanonicalizationPatterns(
909+
OwningRewritePatternList &patterns, MLIRContext *context) {
910+
patterns.insert<AndOfExtUI, AndOfExtSI>(context);
911+
}
912+
913+
//===----------------------------------------------------------------------===//
914+
// OrIOp
915+
//===----------------------------------------------------------------------===//
916+
917+
void arith::OrIOp::getCanonicalizationPatterns(
918+
OwningRewritePatternList &patterns, MLIRContext *context) {
919+
patterns.insert<OrOfExtUI, OrOfExtSI>(context);
920+
}
921+
904922
//===----------------------------------------------------------------------===//
905923
// Verifiers for casts between integers and floats.
906924
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arithmetic/canonicalize.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,52 @@ func @extSIOfExtSI(%arg0: i1) -> i64 {
9999

100100
// -----
101101

102+
// CHECK-LABEL: @andOfExtSI
103+
// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
104+
// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
105+
// CHECK: return %[[ext]]
106+
func @andOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
107+
%ext0 = arith.extsi %arg0 : i8 to i64
108+
%ext1 = arith.extsi %arg1 : i8 to i64
109+
%res = arith.andi %ext0, %ext1 : i64
110+
return %res : i64
111+
}
112+
113+
// CHECK-LABEL: @andOfExtUI
114+
// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
115+
// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
116+
// CHECK: return %[[ext]]
117+
func @andOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
118+
%ext0 = arith.extui %arg0 : i8 to i64
119+
%ext1 = arith.extui %arg1 : i8 to i64
120+
%res = arith.andi %ext0, %ext1 : i64
121+
return %res : i64
122+
}
123+
124+
// CHECK-LABEL: @orOfExtSI
125+
// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
126+
// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
127+
// CHECK: return %[[ext]]
128+
func @orOfExtSI(%arg0: i8, %arg1: i8) -> i64 {
129+
%ext0 = arith.extsi %arg0 : i8 to i64
130+
%ext1 = arith.extsi %arg1 : i8 to i64
131+
%res = arith.ori %ext0, %ext1 : i64
132+
return %res : i64
133+
}
134+
135+
// CHECK-LABEL: @orOfExtUI
136+
// CHECK: %[[comb:.+]] = arith.ori %arg0, %arg1 : i8
137+
// CHECK: %[[ext:.+]] = arith.extui %[[comb]] : i8 to i64
138+
// CHECK: return %[[ext]]
139+
func @orOfExtUI(%arg0: i8, %arg1: i8) -> i64 {
140+
%ext0 = arith.extui %arg0 : i8 to i64
141+
%ext1 = arith.extui %arg1 : i8 to i64
142+
%res = arith.ori %ext0, %ext1 : i64
143+
return %res : i64
144+
}
145+
146+
// -----
147+
102148
// CHECK-LABEL: @indexCastOfSignExtend
103149
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
104150
// CHECK: return %[[res]]

0 commit comments

Comments
 (0)