Skip to content

Commit cc1b9ac

Browse files
committed
[NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.
Reviewed By: bkramer, tra Differential Revision: https://reviews.llvm.org/D117122
1 parent 9c9119a commit cc1b9ac

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
560560
setOperationAction(Op, MVT::f64, Legal);
561561
setOperationAction(Op, MVT::v2f16, Expand);
562562
}
563-
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
564-
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
565-
setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
566-
setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
563+
// max.f16 is supported on sm_80+.
564+
if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
565+
STI.getPTXVersion() >= 70) {
566+
setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
567+
setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
568+
setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
569+
setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
570+
}
567571

568572
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
569573
// No FPOW or FREM in PTX.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,32 @@ multiclass F3<string OpcStr, SDNode OpNode> {
249249
(ins Float32Regs:$a, f32imm:$b),
250250
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
251251
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
252+
253+
def f16rr_ftz :
254+
NVPTXInst<(outs Float16Regs:$dst),
255+
(ins Float16Regs:$a, Float16Regs:$b),
256+
!strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
257+
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
258+
Requires<[useFP16Math, doF32FTZ]>;
259+
def f16rr :
260+
NVPTXInst<(outs Float16Regs:$dst),
261+
(ins Float16Regs:$a, Float16Regs:$b),
262+
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
263+
[(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
264+
Requires<[useFP16Math]>;
265+
266+
def f16x2rr_ftz :
267+
NVPTXInst<(outs Float16x2Regs:$dst),
268+
(ins Float16x2Regs:$a, Float16x2Regs:$b),
269+
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
270+
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
271+
Requires<[useFP16Math, doF32FTZ]>;
272+
def f16x2rr :
273+
NVPTXInst<(outs Float16x2Regs:$dst),
274+
(ins Float16x2Regs:$a, Float16x2Regs:$b),
275+
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
276+
[(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
277+
Requires<[useFP16Math]>;
252278
}
253279

254280
// Template for instructions which take three FP args. The

llvm/test/CodeGen/NVPTX/math-intrins.ll

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
; RUN: llc < %s | FileCheck %s
1+
; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
2+
; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
3+
; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
24
target triple = "nvptx64-nvidia-cuda"
35

46
; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@@ -17,10 +19,14 @@ declare float @llvm.trunc.f32(float) #0
1719
declare double @llvm.trunc.f64(double) #0
1820
declare float @llvm.fabs.f32(float) #0
1921
declare double @llvm.fabs.f64(double) #0
22+
declare half @llvm.minnum.f16(half, half) #0
2023
declare float @llvm.minnum.f32(float, float) #0
2124
declare double @llvm.minnum.f64(double, double) #0
25+
declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
26+
declare half @llvm.maxnum.f16(half, half) #0
2227
declare float @llvm.maxnum.f32(float, float) #0
2328
declare double @llvm.maxnum.f64(double, double) #0
29+
declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
2430
declare float @llvm.fma.f32(float, float, float) #0
2531
declare double @llvm.fma.f64(double, double, double) #0
2632

@@ -193,6 +199,14 @@ define double @abs_double(double %a) {
193199

194200
; ---- min ----
195201

202+
; CHECK-LABEL: min_half
203+
define half @min_half(half %a, half %b) {
204+
; CHECK-NOF16: min.f32
205+
; CHECK-F16: min.f16
206+
%x = call half @llvm.minnum.f16(half %a, half %b)
207+
ret half %x
208+
}
209+
196210
; CHECK-LABEL: min_float
197211
define float @min_float(float %a, float %b) {
198212
; CHECK: min.f32
@@ -228,8 +242,25 @@ define double @min_double(double %a, double %b) {
228242
ret double %x
229243
}
230244

245+
; CHECK-LABEL: min_v2half
246+
define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
247+
; CHECK-NOF16: min.f32
248+
; CHECK-NOF16: min.f32
249+
; CHECK-F16: min.f16x2
250+
%x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
251+
ret <2 x half> %x
252+
}
253+
231254
; ---- max ----
232255

256+
; CHECK-LABEL: max_half
257+
define half @max_half(half %a, half %b) {
258+
; CHECK-NOF16: max.f32
259+
; CHECK-F16: max.f16
260+
%x = call half @llvm.maxnum.f16(half %a, half %b)
261+
ret half %x
262+
}
263+
233264
; CHECK-LABEL: max_imm1
234265
define float @max_imm1(float %a) {
235266
; CHECK: max.f32
@@ -265,6 +296,15 @@ define double @max_double(double %a, double %b) {
265296
ret double %x
266297
}
267298

299+
; CHECK-LABEL: max_v2half
300+
define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
301+
; CHECK-NOF16: max.f32
302+
; CHECK-NOF16: max.f32
303+
; CHECK-F16: max.f16x2
304+
%x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
305+
ret <2 x half> %x
306+
}
307+
268308
; ---- fma ----
269309

270310
; CHECK-LABEL: @fma_float

0 commit comments

Comments
 (0)