Skip to content

Commit b33a131

Browse files
authored
[mlir][arith] Add support for expanding arith.maxnumf/minnumf ops. (#75989)
The maxnum/minnum semantics are described at https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic. This revision also renames the functions in the lit tests so that they match the op names.

Take arith.maxnumf as an example:

```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
  %result = arith.maxnumf %lhs, %rhs : f32
  return %result : f32
}
```

will be expanded to

```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
  %0 = arith.cmpf ugt, %lhs, %rhs : f32
  %1 = arith.select %0, %lhs, %rhs : f32
  %2 = arith.cmpf uno, %lhs, %lhs : f32
  %3 = arith.select %2, %rhs, %1 : f32
  return %3 : f32
}
```

Case 1: Neither LHS nor RHS is NaN, and LHS > RHS. In this case, `%1` is LHS. `%2` is false, so `%3` and `%1` have the same value, and `%3` is LHS. (If instead LHS <= RHS, `%0` is false, so `%1`, and therefore `%3`, is RHS.)

Case 2: LHS is NaN and RHS is not NaN. In this case, `%2` is true, so `%3` is always RHS.

Case 3: LHS is not NaN and RHS is NaN. In this case, `%0` is true (`ugt` also holds when the operands are unordered), so `%1` is LHS. `%2` is false, so `%3` and `%1` have the same value, which is LHS.

Case 4: Both LHS and RHS are NaN. `%1` and RHS are both NaN, so the result is NaN whichever value the final select picks.
1 parent e7bd673 commit b33a131

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,32 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
186186
}
187187
};
188188

189+
template <typename OpTy, arith::CmpFPredicate pred>
190+
struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
191+
public:
192+
using OpRewritePattern<OpTy>::OpRewritePattern;
193+
194+
LogicalResult matchAndRewrite(OpTy op,
195+
PatternRewriter &rewriter) const final {
196+
Value lhs = op.getLhs();
197+
Value rhs = op.getRhs();
198+
199+
Location loc = op.getLoc();
200+
// If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
201+
static_assert(pred == arith::CmpFPredicate::UGT ||
202+
pred == arith::CmpFPredicate::ULT,
203+
"pred must be either UGT or ULT");
204+
Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
205+
Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
206+
207+
// Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
208+
Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
209+
lhs, lhs);
210+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
211+
return success();
212+
}
213+
};
214+
189215
struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
190216
using OpRewritePattern::OpRewritePattern;
191217
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -319,7 +345,9 @@ struct ArithExpandOpsPass
319345
arith::CeilDivUIOp,
320346
arith::FloorDivSIOp,
321347
arith::MaximumFOp,
322-
arith::MinimumFOp
348+
arith::MinimumFOp,
349+
arith::MaxNumFOp,
350+
arith::MinNumFOp
323351
>();
324352

325353
if (includeBf16) {
@@ -365,7 +393,9 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
365393
// clang-format off
366394
patterns.add<
367395
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
368-
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>
396+
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
397+
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
398+
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
369399
>(patterns.getContext());
370400
// clang-format on
371401
}

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ func.func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
176176

177177
// -----
178178

179-
// CHECK-LABEL: func @maxf
180-
func.func @maxf(%a: f32, %b: f32) -> f32 {
179+
// CHECK-LABEL: func @maximumf
180+
func.func @maximumf(%a: f32, %b: f32) -> f32 {
181181
%result = arith.maximumf %a, %b : f32
182182
return %result : f32
183183
}
@@ -190,8 +190,8 @@ func.func @maxf(%a: f32, %b: f32) -> f32 {
190190

191191
// -----
192192

193-
// CHECK-LABEL: func @maxf_vector
194-
func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
193+
// CHECK-LABEL: func @maximumf_vector
194+
func.func @maximumf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
195195
%result = arith.maximumf %a, %b : vector<4xf16>
196196
return %result : vector<4xf16>
197197
}
@@ -204,8 +204,23 @@ func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
204204

205205
// -----
206206

207-
// CHECK-LABEL: func @minf
208-
func.func @minf(%a: f32, %b: f32) -> f32 {
207+
// CHECK-LABEL: func @maxnumf
208+
// Scalar f32 input for the arith.maxnumf expansion; the FileCheck directives
// that follow pin the cmpf/select sequence it is rewritten to.
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
  %max = arith.maxnumf %lhs, %rhs : f32
  return %max : f32
}
212+
213+
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
214+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
215+
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
216+
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
217+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
218+
// CHECK-NEXT: return %[[RESULT]] : f32
219+
220+
// -----
221+
222+
// CHECK-LABEL: func @minimumf
223+
func.func @minimumf(%a: f32, %b: f32) -> f32 {
209224
%result = arith.minimumf %a, %b : f32
210225
return %result : f32
211226
}
@@ -219,6 +234,21 @@ func.func @minf(%a: f32, %b: f32) -> f32 {
219234

220235
// -----
221236

237+
// CHECK-LABEL: func @minnumf
238+
// Scalar f32 input for the arith.minnumf expansion; the FileCheck directives
// that follow pin the cmpf/select sequence it is rewritten to.
func.func @minnumf(%lhs: f32, %rhs: f32) -> f32 {
  %min = arith.minnumf %lhs, %rhs : f32
  return %min : f32
}
242+
243+
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
244+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
245+
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
246+
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
247+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
248+
// CHECK-NEXT: return %[[RESULT]] : f32
249+
250+
// -----
251+
222252
func.func @truncf_f32(%arg0 : f32) -> bf16 {
223253
%0 = arith.truncf %arg0 : f32 to bf16
224254
return %0 : bf16

0 commit comments

Comments
 (0)