Skip to content

Commit 256ac3f

Browse files
committed
support fadd, fsub, fmul, fma and load on v2f32
1 parent 69d78c2 commit 256ac3f

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10991099
// Vector Setting
11001100
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11011101
if (SimpleVT.isVector()) {
1102-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1103-
"Unexpected vector type");
1104-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1105-
FromTypeWidth = 32;
1102+
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1103+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1104+
FromTypeWidth = 32;
1105+
else if (LoadedVT == MVT::v2f32)
1106+
// v2f32 is loaded using ld.b64
1107+
FromTypeWidth = 64;
1108+
else
1109+
llvm_unreachable("Unexpected vector type");
11061110
}
11071111

11081112
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
405405
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
406406
[(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
407407
Requires<[allowFMA]>;
408+
// Packed v2f32 register-register form used when allowFMA and doF32FTZ hold.
// Both operands and the result are carried in Int64Regs, and the emitted
// mnemonic is OpcStr followed by ".ftz.f32x2".
// hasF32x2Instructions is required here as well, so this record is gated the
// same way as the noFMA `_rnf32x2rr*` records and FMA_F32x2 in this file.
def f32x2rr_ftz :
409+
NVPTXInst<(outs Int64Regs:$dst),
410+
(ins Int64Regs:$a, Int64Regs:$b),
411+
!strconcat(OpcStr, ".ftz.f32x2 \t$dst, $a, $b;"),
412+
[(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
413+
Requires<[hasF32x2Instructions, allowFMA, doF32FTZ]>;
414+
// Packed v2f32 register-register form used when allowFMA holds (no .ftz).
// Both operands and the result are carried in Int64Regs, and the emitted
// mnemonic is OpcStr followed by ".f32x2".
// hasF32x2Instructions is required here as well, so this record is gated the
// same way as the noFMA `_rnf32x2rr*` records and FMA_F32x2 in this file.
def f32x2rr :
415+
NVPTXInst<(outs Int64Regs:$dst),
416+
(ins Int64Regs:$a, Int64Regs:$b),
417+
!strconcat(OpcStr, ".f32x2 \t$dst, $a, $b;"),
418+
[(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
419+
Requires<[hasF32x2Instructions, allowFMA]>;
408420

409421
def f16rr_ftz :
410422
NVPTXInst<(outs Int16Regs:$dst),
@@ -529,6 +541,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
529541
!strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
530542
[(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
531543
Requires<[hasBF16Math, noFMA]>;
544+
// Packed v2f32 register-register form with an explicit ".rn" modifier and
// ".ftz", selected when hasF32x2Instructions, noFMA and doF32FTZ all hold.
// Operands and result are carried in Int64Regs; the pattern matches OpNode
// applied to two v2f32 values.
def _rnf32x2rr_ftz :
545+
NVPTXInst<(outs Int64Regs:$dst),
546+
(ins Int64Regs:$a, Int64Regs:$b),
547+
!strconcat(OpcStr, ".rn.ftz.f32x2 \t$dst, $a, $b;"),
548+
[(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
549+
Requires<[hasF32x2Instructions, noFMA, doF32FTZ]>;
550+
// Packed v2f32 register-register form with an explicit ".rn" modifier and no
// ".ftz", selected when hasF32x2Instructions and noFMA hold.
// Operands and result are carried in Int64Regs; the pattern matches OpNode
// applied to two v2f32 values.
def _rnf32x2rr :
551+
NVPTXInst<(outs Int64Regs:$dst),
552+
(ins Int64Regs:$a, Int64Regs:$b),
553+
!strconcat(OpcStr, ".rn.f32x2 \t$dst, $a, $b;"),
554+
[(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
555+
Requires<[hasF32x2Instructions, noFMA]>;
532556
}
533557

534558
// Template for operations which take two f32 or f64 operands. Provides three
@@ -1432,6 +1456,13 @@ multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred
14321456
Requires<[hasBF16Math, Pred]>;
14331457
}
14341458

1459+
// Three-operand fused multiply-add on v2f32.
//   OpcStr - mnemonic prefix; ".f32x2" and the operand list are appended.
//   Pred   - extra predicate required in addition to hasF32x2Instructions.
// $a, $b, $c and $res are all carried in Int64Regs; the selection pattern is
// the generic `fma` node on v2f32 operands.
class FMA_F32x2<string OpcStr, Predicate Pred>
1460+
: NVPTXInst<(outs Int64Regs:$res),
1461+
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1462+
OpcStr # ".f32x2 \t$res, $a, $b, $c;",
1463+
[(set v2f32:$res, (fma v2f32:$a, v2f32:$b, v2f32:$c))]>,
1464+
Requires<[hasF32x2Instructions, Pred]>;
1465+
14351466
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
14361467
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
14371468
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
@@ -1440,6 +1471,8 @@ defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
14401471
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
14411472
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
14421473
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
1474+
// v2f32 fma records built from FMA_F32x2: the first emits "fma.rn.ftz.f32x2"
// and additionally requires doF32FTZ; the second emits "fma.rn.f32x2" and
// passes True as its extra predicate.
def FMA32x2_ftz : FMA_F32x2<"fma.rn.ftz", doF32FTZ>;
1475+
def FMA32x2 : FMA_F32x2<"fma.rn", True>;
14431476
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
14441477

14451478
// sin/cos

0 commit comments

Comments
 (0)