Skip to content

Commit dfc5125

Browse files
authored
[NVPTX] Consistently check fast-math flags when lowering fsqrt (#143776)
Ensure that we check the global, function-level, and instruction-level flags when considering whether to use `sqrt.rn` or `sqrt.approx` to lower either `@llvm.sqrt.f32` or `@llvm.nvvm.sqrt.f`
1 parent 054f4a5 commit dfc5125

File tree

8 files changed

+695
-182
lines changed

8 files changed

+695
-182
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ NVPTXDAGToDAGISel::getDivF32Level(const SDNode *N) const {
7171
return Subtarget->getTargetLowering()->getDivF32Level(*MF, *N);
7272
}
7373

74-
bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
75-
return Subtarget->getTargetLowering()->usePrecSqrtF32();
74+
bool NVPTXDAGToDAGISel::usePrecSqrtF32(const SDNode *N) const {
75+
return Subtarget->getTargetLowering()->usePrecSqrtF32(*MF, N);
7676
}
7777

7878
bool NVPTXDAGToDAGISel::useF32FTZ() const {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4444
bool doMulWide;
4545

4646
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
47-
bool usePrecSqrtF32() const;
47+
bool usePrecSqrtF32(const SDNode *N) const;
4848
bool useF32FTZ() const;
4949
bool allowFMA() const;
5050
bool allowUnsafeFPMath() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,23 @@ NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
134134
return NVPTX::DivPrecisionLevel::IEEE754;
135135
}
136136

137-
bool NVPTXTargetLowering::usePrecSqrtF32() const {
138-
if (UsePrecSqrtF32.getNumOccurrences() > 0) {
139-
// If nvptx-prec-sqrtf32 is used on the command-line, always honor it
137+
bool NVPTXTargetLowering::usePrecSqrtF32(const MachineFunction &MF,
138+
const SDNode *N) const {
139+
// If nvptx-prec-sqrtf32 is used on the command-line, always honor it
140+
if (UsePrecSqrtF32.getNumOccurrences() > 0)
140141
return UsePrecSqrtF32;
141-
} else {
142-
// Otherwise, use sqrt.approx if fast math is enabled
143-
return !getTargetMachine().Options.UnsafeFPMath;
142+
143+
// Otherwise, use sqrt.approx if fast math is enabled
144+
if (allowUnsafeFPMath(MF))
145+
return false;
146+
147+
if (N) {
148+
const SDNodeFlags Flags = N->getFlags();
149+
if (Flags.hasApproximateFuncs())
150+
return false;
144151
}
152+
153+
return true;
145154
}
146155

147156
bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
@@ -1134,7 +1143,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
11341143
bool &UseOneConst,
11351144
bool Reciprocal) const {
11361145
if (!(Enabled == ReciprocalEstimate::Enabled ||
1137-
(Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1146+
(Enabled == ReciprocalEstimate::Unspecified &&
1147+
!usePrecSqrtF32(DAG.getMachineFunction()))))
11381148
return SDValue();
11391149

11401150
if (ExtraSteps == ReciprocalEstimate::Unspecified)

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ class NVPTXTargetLowering : public TargetLowering {
225225

226226
// Get whether we should use a precise or approximate 32-bit floating point
227227
// sqrt instruction.
228-
bool usePrecSqrtF32() const;
228+
bool usePrecSqrtF32(const MachineFunction &MF,
229+
const SDNode *N = nullptr) const;
229230

230231
// Get whether we should use instructions that flush floating-point denormals
231232
// to sign-preserving zero.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
151151

152152
def doMulWide : Predicate<"doMulWide">;
153153

154-
def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
155-
def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
156-
157154
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158155
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159156
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,15 +1520,18 @@ def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64", F64RT, F64RT, int_nvvm_sqrt_rz_
15201520
def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64", F64RT, F64RT, int_nvvm_sqrt_rm_d>;
15211521
def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64", F64RT, F64RT, int_nvvm_sqrt_rp_d>;
15221522

1523+
def fsqrt_approx : PatFrags<(ops node:$a),
1524+
[(fsqrt node:$a),
1525+
(int_nvvm_sqrt_f node:$a)], [{
1526+
return !usePrecSqrtF32(N);
1527+
}]>;
1528+
15231529
// nvvm_sqrt intrinsic
1524-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1525-
(INT_NVVM_SQRT_RN_FTZ_F $a)>, Requires<[doF32FTZ, do_SQRTF32_RN]>;
1526-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1527-
(INT_NVVM_SQRT_RN_F $a)>, Requires<[do_SQRTF32_RN]>;
1528-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1529-
(INT_NVVM_SQRT_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
1530-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1531-
(INT_NVVM_SQRT_APPROX_F $a)>;
1530+
def : Pat<(int_nvvm_sqrt_f f32:$a), (INT_NVVM_SQRT_RN_FTZ_F $a)>, Requires<[doF32FTZ]>;
1531+
def : Pat<(int_nvvm_sqrt_f f32:$a), (INT_NVVM_SQRT_RN_F $a)>;
1532+
1533+
def : Pat<(fsqrt_approx f32:$a), (INT_NVVM_SQRT_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
1534+
def : Pat<(fsqrt_approx f32:$a), (INT_NVVM_SQRT_APPROX_F $a)>;
15321535

15331536
//
15341537
// Rsqrt
@@ -1551,20 +1554,14 @@ def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_f f32:$a)),
15511554
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
15521555
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
15531556
Requires<[doRsqrtOpt]>;
1554-
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1555-
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
1556-
(INT_NVVM_RSQRT_APPROX_F $a)>,
1557-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1558-
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
1559-
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
1560-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
15611557

1562-
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
1558+
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1559+
def: Pat<(fdiv f32imm_1, (fsqrt_approx f32:$a)),
15631560
(INT_NVVM_RSQRT_APPROX_F $a)>,
1564-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1565-
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
1561+
Requires<[doRsqrtOpt, doNoF32FTZ]>;
1562+
def: Pat<(fdiv f32imm_1, (fsqrt_approx f32:$a)),
15661563
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
1567-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
1564+
Requires<[doRsqrtOpt, doF32FTZ]>;
15681565
//
15691566
// Add
15701567
//

0 commit comments

Comments
 (0)