// Patch summary (free-text preamble; everything from the first "diff --git"
// header onward is the unmodified patch).
//
// NVPTXISelLowering.cpp, lowerFREM:
//   * Copies the frem node's flags into NewFlags, clears NoNaNs and NoInfs on
//     that copy, and uses NewFlags for the FDIV, FTRUNC, FMUL and FSUB nodes
//     (FMUL and FSUB still OR in SDNodeFlags::AllowContract).
//   * The early "return Sub" now fires when AllowUnsafeFPMath is set, or when
//     the original flags have BOTH hasNoInfs() and hasApproximateFuncs();
//     before this patch hasNoInfs() alone was sufficient.
//
// test/CodeGen/NVPTX/frem-ninf-nnan.ll (new file):
//   * Runs llc with --stop-after=finalize-isel -mcpu=sm_60 on
//     "frem ninf nnan float" and FileChecks the flags printed on
//     FDIV32rr_prec and FNEGf32.
//
// test/CodeGen/NVPTX/frem.ll:
//   * The frem_f16_ninf, frem_f32_ninf and frem_f64_ninf inputs change from
//     "frem ninf" to "frem ninf afn"; the f16 and f32 NORMAL expectations
//     change from div.rn.f32 to div.approx.f32.
//
// NOTE(review): the new test expects "nnan ninf FDIV32rr_prec", but the
// lowering hunk builds FDIV from NewFlags, which has nnan/ninf cleared. These
// look inconsistent with each other -- confirm which one is intended.
// NOTE(review): this copy of the patch appears to have its line breaks
// collapsed into spaces; presumably they must be restored before it can be
// applied -- confirm against the original patch.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index bb0aeb493ed48..e18c75cd49a95 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -2793,14 +2793,18 @@ static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG, EVT Ty = Op.getValueType(); SDNodeFlags Flags = Op->getFlags(); - SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y, Flags); - SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div, Flags); + // fdiv can still generate inf and nan when nnan and ninf are set. + SDNodeFlags NewFlags = Flags; + NewFlags.setNoNaNs(false); + NewFlags.setNoInfs(false); + SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y, NewFlags); + SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div, NewFlags); SDValue Mul = DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y, - Flags | SDNodeFlags::AllowContract); + NewFlags | SDNodeFlags::AllowContract); SDValue Sub = DAG.getNode(ISD::FSUB, DL, Ty, X, Mul, - Flags | SDNodeFlags::AllowContract); + NewFlags | SDNodeFlags::AllowContract); - if (AllowUnsafeFPMath || Flags.hasNoInfs()) + if (AllowUnsafeFPMath || (Flags.hasNoInfs() && Flags.hasApproximateFuncs())) return Sub; // If Y is infinite, return X diff --git a/llvm/test/CodeGen/NVPTX/frem-ninf-nnan.ll b/llvm/test/CodeGen/NVPTX/frem-ninf-nnan.ll new file mode 100644 index 0000000000000..b1d498257b5ad --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/frem-ninf-nnan.ll @@ -0,0 +1,11 @@ +; RUN: llc %s --stop-after=finalize-isel -mcpu=sm_60 -o - | FileCheck %s + +target triple = "nvptx64-unknown-cuda" + +define float @frem_ninf_nnan(float %a, float %b) { + ; CHECK: nnan ninf FDIV32rr_prec + ; CHECK-NOT: nnan ninf contract FNEGf32 + ; CHECK: contract FNEGf32 + %r = frem ninf nnan float %a, %b + ret float %r +} diff --git a/llvm/test/CodeGen/NVPTX/frem.ll b/llvm/test/CodeGen/NVPTX/frem.ll index 5805aed1bebe6..479205cd08119 100644 --- a/llvm/test/CodeGen/NVPTX/frem.ll +++ b/llvm/test/CodeGen/NVPTX/frem.ll 
@@ -147,14 +147,14 @@ define half @frem_f16_ninf(half %a, half %b) { ; NORMAL-NEXT: ld.param.b16 %rs2, [frem_f16_ninf_param_1]; ; NORMAL-NEXT: cvt.f32.f16 %r1, %rs2; ; NORMAL-NEXT: cvt.f32.f16 %r2, %rs1; -; NORMAL-NEXT: div.rn.f32 %r3, %r2, %r1; +; NORMAL-NEXT: div.approx.f32 %r3, %r2, %r1; ; NORMAL-NEXT: cvt.rzi.f32.f32 %r4, %r3; ; NORMAL-NEXT: neg.f32 %r5, %r4; ; NORMAL-NEXT: fma.rn.f32 %r6, %r5, %r1, %r2; ; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %r6; ; NORMAL-NEXT: st.param.b16 [func_retval0], %rs3; ; NORMAL-NEXT: ret; - %r = frem ninf half %a, %b + %r = frem ninf afn half %a, %b ret half %r } @@ -180,13 +180,13 @@ define float @frem_f32_ninf(float %a, float %b) { ; NORMAL-NEXT: // %bb.0: ; NORMAL-NEXT: ld.param.b32 %r1, [frem_f32_ninf_param_0]; ; NORMAL-NEXT: ld.param.b32 %r2, [frem_f32_ninf_param_1]; -; NORMAL-NEXT: div.rn.f32 %r3, %r1, %r2; +; NORMAL-NEXT: div.approx.f32 %r3, %r1, %r2; ; NORMAL-NEXT: cvt.rzi.f32.f32 %r4, %r3; ; NORMAL-NEXT: neg.f32 %r5, %r4; ; NORMAL-NEXT: fma.rn.f32 %r6, %r5, %r2, %r1; ; NORMAL-NEXT: st.param.b32 [func_retval0], %r6; ; NORMAL-NEXT: ret; - %r = frem ninf float %a, %b + %r = frem ninf afn float %a, %b ret float %r } @@ -218,7 +218,7 @@ define double @frem_f64_ninf(double %a, double %b) { ; NORMAL-NEXT: fma.rn.f64 %rd6, %rd5, %rd2, %rd1; ; NORMAL-NEXT: st.param.b64 [func_retval0], %rd6; ; NORMAL-NEXT: ret; - %r = frem ninf double %a, %b + %r = frem ninf afn double %a, %b ret double %r }