Skip to content

Commit 56e9d95

Browse files
committed
[RFC][mlir] Conditional support for fast-math attributes.
This patch suggests changes for operations that support arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface. Some of the operations may have fast-math flags not equal to `none` only if they operate on floating point values. This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types and my goal to add fast-math support for `arith.select` operation that may produce results of any type. The changes add new isArithFastMathApplicable/isFastmathApplicable methods to the above interfaces that tell whether an operation supporting the interface may have non-none fast-math flags. LLVM dialect isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380 ARITH dialect isArithFastMathApplicable is more relaxed, because it has to support custom MLIR types. This is the area where improvements are needed (see TODO comments). I will appreciate feedback here. HLFIR dialect is a another example where conditional fast-math support may be applied currently.
1 parent 0572580 commit 56e9d95

File tree

17 files changed

+353
-68
lines changed

17 files changed

+353
-68
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
24942494
llvm::cast<mlir::SymbolRefAttr>(callee));
24952495
setOperand(0, llvm::cast<mlir::Value>(callee));
24962496
}
2497+
2498+
/// Always allow FastMathFlags for fir.call's.
2499+
/// It is required to be able to propagate the call site's
2500+
/// FastMathFlags to the operations resulting from inlining
2501+
/// (if any) of a fir.call (see SimplifyIntrinsics pass).
2502+
/// We could analyze the arguments' data types to see if there are
2503+
/// any floating point types, but this is unreliable. For example,
2504+
/// the runtime calls mostly take !fir.box<none> arguments,
2505+
/// and tracking them to the definitions may be not easy.
2506+
/// TODO: this should be restricted to fir.runtime calls,
2507+
/// because FastMathFlags for the user calls must come
2508+
/// from the function body, not the call site.
2509+
bool isArithFastMathApplicable() {
2510+
return true;
2511+
}
24972512
}];
24982513
}
24992514

@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
26722687
}
26732688

26742689
static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
2690+
2691+
/// Always allow FastMathFlags on fir.cmpc.
2692+
/// It does not produce a floating point result, but
2693+
/// LLVM is currently relying on fast-math flags attached
2694+
/// to floating point comparison.
2695+
/// This can be removed whenever LLVM stops doing it.
2696+
bool isArithFastMathApplicable() {
2697+
return true;
2698+
}
26752699
}];
26762700
}
26772701

@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
27352759
static bool isPointerCompatible(mlir::Type ty);
27362760
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
27372761
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
2762+
2763+
// FIXME: fir.convert should support ArithFastMathInterface.
27382764
}];
27392765
let hasCanonicalizer = 1;
27402766
}

flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
139139
/// Scalar integer or a sequence of integers (via boxed array or expr).
140140
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
141141

142+
/// Return true iff FastMathFlagsAttr is applicable
143+
/// to the given HLFIR dialect operation that supports
144+
/// ArithFastMathInterface.
145+
bool isArithFastMathApplicable(mlir::Operation *op);
146+
142147
} // namespace hlfir
143148

144149
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
434434
}];
435435

436436
let hasVerifier = 1;
437+
438+
let extraClassDeclaration = [{
439+
bool isArithFastMathApplicable() {
440+
return hlfir::isArithFastMathApplicable(getOperation());
441+
}
442+
}];
437443
}
438444

439445
def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
461467
}];
462468

463469
let hasVerifier = 1;
470+
471+
let extraClassDeclaration = [{
472+
bool isArithFastMathApplicable() {
473+
return hlfir::isArithFastMathApplicable(getOperation());
474+
}
475+
}];
464476
}
465477

466478
def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
487499
}];
488500

489501
let hasVerifier = 1;
502+
503+
let extraClassDeclaration = [{
504+
bool isArithFastMathApplicable() {
505+
return hlfir::isArithFastMathApplicable(getOperation());
506+
}
507+
}];
490508
}
491509

492510
def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
513531
}];
514532

515533
let hasVerifier = 1;
534+
535+
let extraClassDeclaration = [{
536+
bool isArithFastMathApplicable() {
537+
return hlfir::isArithFastMathApplicable(getOperation());
538+
}
539+
}];
516540
}
517541

518542
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
539563
}];
540564

541565
let hasVerifier = 1;
566+
567+
let extraClassDeclaration = [{
568+
bool isArithFastMathApplicable() {
569+
return hlfir::isArithFastMathApplicable(getOperation());
570+
}
571+
}];
542572
}
543573

544574
def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
604634
}];
605635

606636
let hasVerifier = 1;
637+
638+
let extraClassDeclaration = [{
639+
bool isArithFastMathApplicable() {
640+
return hlfir::isArithFastMathApplicable(getOperation());
641+
}
642+
}];
607643
}
608644

609645
def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
628664
}];
629665

630666
let hasVerifier = 1;
667+
668+
let extraClassDeclaration = [{
669+
bool isArithFastMathApplicable() {
670+
return hlfir::isArithFastMathApplicable(getOperation());
671+
}
672+
}];
631673
}
632674

633675
def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
655697
let hasCanonicalizeMethod = 1;
656698

657699
let hasVerifier = 1;
700+
701+
let extraClassDeclaration = [{
702+
bool isArithFastMathApplicable() {
703+
return hlfir::isArithFastMathApplicable(getOperation());
704+
}
705+
}];
658706
}
659707

660708
def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
697745
}];
698746

699747
let hasVerifier = 1;
748+
749+
let extraClassDeclaration = [{
750+
bool isArithFastMathApplicable() {
751+
return hlfir::isArithFastMathApplicable(getOperation());
752+
}
753+
}];
700754
}
701755

702756
def hlfir_CShiftOp

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
786786

787787
void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
788788
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
789-
if (fmi) {
790-
// TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
791-
// For now set the attribute by the name.
789+
if (fmi && fmi.isArithFastMathApplicable()) {
792790
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
793791
if (fastMathFlags != mlir::arith::FastMathFlags::none)
794792
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
589589
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
590590
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
591591
attrConvert(call);
592-
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
593-
call, resultTys, adaptor.getOperands(),
592+
auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
593+
call.getLoc(), resultTys, adaptor.getOperands(),
594594
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
595595
adaptor.getOperands().size()));
596+
auto fmi =
597+
mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
598+
if (!fmi.isFastmathApplicable())
599+
llvmCall->setAttr(
600+
mlir::LLVM::CallOp::getFastmathAttrName(),
601+
mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
602+
mlir::LLVM::FastmathFlags::none));
603+
rewriter.replaceOp(call, llvmCall);
596604
return mlir::success();
597605
}
598606
};

flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
237237
mlir::Type elementType = getFortranElementType(unwrappedType);
238238
return mlir::isa<mlir::IntegerType>(elementType);
239239
}
240+
241+
bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
242+
if (llvm::any_of(op->getResults(), [](mlir::Value v) {
243+
mlir::Type elementType = getFortranElementType(v.getType());
244+
return mlir::arith::ArithFastMathInterface::isCompatibleType(
245+
elementType);
246+
}))
247+
return true;
248+
if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
249+
mlir::Type elementType = getFortranElementType(v.getType());
250+
return mlir::arith::ArithFastMathInterface::isCompatibleType(
251+
elementType);
252+
}))
253+
return true;
254+
255+
return true;
256+
}

flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
5656
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
5757
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
5858
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
59-
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
59+
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
6060
%48 = llvm.mlir.constant(9 : i32) : i32
6161
%49 = llvm.mlir.zero : !llvm.ptr
6262
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32

flang/test/Fir/tbaa.fir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ module {
136136
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
137137
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
138138
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
139-
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
139+
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
140140
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
141141
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
142142
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
188188
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
189189
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
190190
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
191-
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
192-
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
191+
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
192+
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
193193
// CHECK: llvm.return
194194
// CHECK: }
195195
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
12111211
The destination type must to be strictly wider than the source type.
12121212
When operating on vectors, casts elementwise.
12131213
}];
1214+
let extraClassDeclaration = [{
1215+
bool isApplicable() { return true; }
1216+
}];
12141217
let hasVerifier = 1;
12151218
let hasFolder = 1;
12161219

@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
15451548
let hasCanonicalizer = 1;
15461549
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
15471550
attr-dict `:` type($lhs)}];
1551+
1552+
let extraClassDeclaration = [{
1553+
/// Always allow FastMathFlags on arith.cmpf.
1554+
/// It does not produce a floating point result, but
1555+
/// LLVM is currently relying on fast-math flags attached
1556+
/// to floating point comparison.
1557+
/// This can be removed whenever LLVM stops doing it.
1558+
bool isArithFastMathApplicable() {
1559+
return true;
1560+
}
1561+
}];
15481562
}
15491563

15501564
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
2222

2323
let cppNamespace = "::mlir::arith";
2424

25-
let methods = [
26-
InterfaceMethod<
27-
/*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
28-
/*returnType=*/ "FastMathFlagsAttr",
29-
/*methodName=*/ "getFastMathFlagsAttr",
30-
/*args=*/ (ins),
31-
/*methodBody=*/ [{}],
32-
/*defaultImpl=*/ [{
25+
let methods =
26+
[InterfaceMethod<
27+
/*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
28+
/*returnType=*/"FastMathFlagsAttr",
29+
/*methodName=*/"getFastMathFlagsAttr",
30+
/*args=*/(ins),
31+
/*methodBody=*/[{}],
32+
/*defaultImpl=*/[{
3333
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
3434
return op.getFastmathAttr();
35-
}]
36-
>,
37-
StaticInterfaceMethod<
38-
/*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
35+
}]>,
36+
StaticInterfaceMethod<
37+
/*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
3938
for the operation}],
40-
/*returnType=*/ "StringRef",
41-
/*methodName=*/ "getFastMathAttrName",
42-
/*args=*/ (ins),
43-
/*methodBody=*/ [{}],
44-
/*defaultImpl=*/ [{
39+
/*returnType=*/"StringRef",
40+
/*methodName=*/"getFastMathAttrName",
41+
/*args=*/(ins),
42+
/*methodBody=*/[{}],
43+
/*defaultImpl=*/[{
4544
return "fastmath";
46-
}]
47-
>
45+
}]>,
46+
InterfaceMethod<
47+
/*desc=*/[{Returns true iff FastMathFlagsAttr attribute
48+
is applicable to the operation that supports
49+
ArithFastMathInterface. If it returns false,
50+
then the FastMathFlagsAttr of the operation
51+
must be nullptr or have 'none' value}],
52+
/*returnType=*/"bool",
53+
/*methodName=*/"isArithFastMathApplicable",
54+
/*args=*/(ins),
55+
/*methodBody=*/[{}],
56+
/*defaultImpl=*/[{
57+
return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
58+
}]>];
4859

49-
];
60+
let extraClassDeclaration = [{
61+
/// Returns true iff the given type is a floating point type
62+
/// or contains one.
63+
static bool isCompatibleType(::mlir::Type);
64+
65+
/// Default implementation of isArithFastMathApplicable().
66+
/// It returns true iff any of the results of the operations
67+
/// has a type that is compatible with fast-math.
68+
bool isApplicableImpl();
69+
}];
70+
71+
let verify = [{
72+
auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
73+
auto attr = fmi.getFastMathFlagsAttr();
74+
if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
75+
!fmi.isArithFastMathApplicable())
76+
return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
77+
return ::mlir::success();
78+
}];
5079
}
5180

5281
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {

0 commit comments

Comments
 (0)