@@ -840,9 +840,43 @@ struct SelectToNot : public OpRewritePattern<SelectOp> {
840
840
}
841
841
};
842
842
843
+ // select %arg, %c1, %c0 => extui %arg
844
+ struct SelectToExtUI : public OpRewritePattern <SelectOp> {
845
+ using OpRewritePattern<SelectOp>::OpRewritePattern;
846
+
847
+ LogicalResult matchAndRewrite (SelectOp op,
848
+ PatternRewriter &rewriter) const override {
849
+ // Cannot extui i1 to i1, or i1 to f32
850
+ if (!op.getType ().isa <IntegerType>() || op.getType ().isInteger (1 ))
851
+ return failure ();
852
+
853
+ // select %x, c1, %c0 => extui %arg
854
+ if (matchPattern (op.getTrueValue (), m_One ()))
855
+ if (matchPattern (op.getFalseValue (), m_Zero ())) {
856
+ rewriter.replaceOpWithNewOp <arith::ExtUIOp>(op, op.getType (),
857
+ op.getCondition ());
858
+ return success ();
859
+ }
860
+
861
+ // select %x, c0, %c1 => extui (xor %arg, true)
862
+ if (matchPattern (op.getTrueValue (), m_Zero ()))
863
+ if (matchPattern (op.getFalseValue (), m_One ())) {
864
+ rewriter.replaceOpWithNewOp <arith::ExtUIOp>(
865
+ op, op.getType (),
866
+ rewriter.create <arith::XOrIOp>(
867
+ op.getLoc (), op.getCondition (),
868
+ rewriter.create <arith::ConstantIntOp>(
869
+ op.getLoc (), 1 , op.getCondition ().getType ())));
870
+ return success ();
871
+ }
872
+
873
+ return failure ();
874
+ }
875
+ };
876
+
843
877
void SelectOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
844
878
MLIRContext *context) {
845
- results.insert <SelectToNot>(context);
879
+ results.insert <SelectToNot, SelectToExtUI >(context);
846
880
}
847
881
848
882
OpFoldResult SelectOp::fold (ArrayRef<Attribute> operands) {
@@ -861,6 +895,12 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
861
895
if (matchPattern (condition, m_Zero ()))
862
896
return falseVal;
863
897
898
+ // select %x, true, false => %x
899
+ if (getType ().isInteger (1 ))
900
+ if (matchPattern (getTrueValue (), m_One ()))
901
+ if (matchPattern (getFalseValue (), m_Zero ()))
902
+ return condition;
903
+
864
904
if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp ())) {
865
905
auto pred = cmp.getPredicate ();
866
906
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
0 commit comments