Skip to content

Commit 1b4ec1f

Browse files
committed
enable ftz support
Also temporarily disable -O3 in testing, since it exposes an existing bug in how test_extract_i() is lowered when optimized.
1 parent e3dc226 commit 1b4ec1f

File tree

5 files changed

+426
-173
lines changed

5 files changed

+426
-173
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
115115
def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
116116
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
117117
]>;
118+
def SDTIntTernaryOp : SDTypeProfile<1, 3, [ // fma32x2
119+
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>
120+
]>;
118121
def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
119122
SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
120123
]>;

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
121121
case NVPTXISD::SETP_BF16X2:
122122
SelectSETP_BF16X2(N);
123123
return;
124-
case NVPTXISD::FADD_F32X2:
125-
case NVPTXISD::FSUB_F32X2:
126-
case NVPTXISD::FMUL_F32X2:
127-
case NVPTXISD::FMA_F32X2:
128-
SelectF32X2Op(N);
129-
return;
130124
case NVPTXISD::LoadV2:
131125
case NVPTXISD::LoadV4:
132126
if (tryLoadVector(N))
@@ -305,30 +299,6 @@ bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
305299
return true;
306300
}
307301

308-
void NVPTXDAGToDAGISel::SelectF32X2Op(SDNode *N) {
309-
unsigned Opcode;
310-
switch (N->getOpcode()) {
311-
case NVPTXISD::FADD_F32X2:
312-
Opcode = NVPTX::FADD_F32X2;
313-
break;
314-
case NVPTXISD::FSUB_F32X2:
315-
Opcode = NVPTX::FSUB_F32X2;
316-
break;
317-
case NVPTXISD::FMUL_F32X2:
318-
Opcode = NVPTX::FMUL_F32X2;
319-
break;
320-
case NVPTXISD::FMA_F32X2:
321-
Opcode = NVPTX::FMA_F32X2;
322-
break;
323-
default:
324-
llvm_unreachable("Unexpected opcode!");
325-
}
326-
SDLoc DL(N);
327-
SmallVector<SDValue> NewOps(N->ops());
328-
SDNode *NewNode = CurDAG->getMachineNode(Opcode, DL, MVT::i64, NewOps);
329-
ReplaceNode(N, NewNode);
330-
}
331-
332302
// Find all instances of extract_vector_elt that use this v2f16 vector
333303
// and coalesce them into a scattering move instruction.
334304
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
8888
bool tryConstantFP(SDNode *N);
8989
bool SelectSETP_F16X2(SDNode *N);
9090
bool SelectSETP_BF16X2(SDNode *N);
91-
void SelectF32X2Op(SDNode *N);
9291
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
9392
void SelectV2I64toI128(SDNode *N);
9493
void SelectI128toV2I64(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,27 +1581,50 @@ def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64 \t$dst, $src0, $src1;",
15811581
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64 \t$dst, $src0, $src1;",
15821582
Float64Regs, Float64Regs, Float64Regs, int_nvvm_add_rp_d>;
15831583

1584-
// F32x2 ops (sm_100+)
1585-
1586-
def FADD_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1587-
(ins Int64Regs:$a, Int64Regs:$b),
1588-
"add.rn.f32x2 \t$res, $a, $b;", []>,
1589-
Requires<[hasF32x2Instructions]>;
1590-
1591-
def FSUB_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1592-
(ins Int64Regs:$a, Int64Regs:$b),
1593-
"sub.rn.f32x2 \t$res, $a, $b;", []>,
1594-
Requires<[hasF32x2Instructions]>;
1595-
1596-
def FMUL_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1597-
(ins Int64Regs:$a, Int64Regs:$b),
1598-
"mul.rn.f32x2 \t$res, $a, $b;", []>,
1599-
Requires<[hasF32x2Instructions]>;
1600-
1601-
def FMA_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1602-
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1603-
"fma.rn.f32x2 \t$res, $a, $b;", []>,
1604-
Requires<[hasF32x2Instructions]>;
1584+
// packed f32 ops (sm_100+)
1585+
class F32x2Op2<string OpcStr, Predicate Pred>
1586+
: NVPTXInst<(outs Int64Regs:$res),
1587+
(ins Int64Regs:$a, Int64Regs:$b),
1588+
OpcStr # ".f32x2 \t$res, $a, $b;", []>,
1589+
Requires<[hasF32x2Instructions, Pred]>;
1590+
class F32x2Op3<string OpcStr, Predicate Pred>
1591+
: NVPTXInst<(outs Int64Regs:$res),
1592+
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1593+
OpcStr # ".f32x2 \t$res, $a, $b, $c;", []>,
1594+
Requires<[hasF32x2Instructions, Pred]>;
1595+
1596+
def fadd32x2_nvptx : SDNode<"NVPTXISD::FADD_F32X2", SDTIntBinOp>;
1597+
def fsub32x2_nvptx : SDNode<"NVPTXISD::FSUB_F32X2", SDTIntBinOp>;
1598+
def fmul32x2_nvptx : SDNode<"NVPTXISD::FMUL_F32X2", SDTIntBinOp>;
1599+
def fma32x2_nvptx : SDNode<"NVPTXISD::FMA_F32X2", SDTIntTernaryOp>;
1600+
1601+
def FADD32x2 : F32x2Op2<"add.rn", doNoF32FTZ>;
1602+
def FSUB32x2 : F32x2Op2<"sub.rn", doNoF32FTZ>;
1603+
def FMUL32x2 : F32x2Op2<"mul.rn", doNoF32FTZ>;
1604+
def FMA32x2 : F32x2Op3<"fma.rn", doNoF32FTZ>;
1605+
1606+
def : Pat<(fadd32x2_nvptx i64:$a, i64:$b),
1607+
(FADD32x2 $a, $b)>, Requires<[doNoF32FTZ]>;
1608+
def : Pat<(fsub32x2_nvptx i64:$a, i64:$b),
1609+
(FSUB32x2 $a, $b)>, Requires<[doNoF32FTZ]>;
1610+
def : Pat<(fmul32x2_nvptx i64:$a, i64:$b),
1611+
(FMUL32x2 $a, $b)>, Requires<[doNoF32FTZ]>;
1612+
def : Pat<(fma32x2_nvptx i64:$a, i64:$b, i64:$c),
1613+
(FMA32x2 $a, $b, $c)>, Requires<[doNoF32FTZ]>;
1614+
1615+
def FADD32x2_ftz : F32x2Op2<"add.rn.ftz", doF32FTZ>;
1616+
def FSUB32x2_ftz : F32x2Op2<"sub.rn.ftz", doF32FTZ>;
1617+
def FMUL32x2_ftz : F32x2Op2<"mul.rn.ftz", doF32FTZ>;
1618+
def FMA32x2_ftz : F32x2Op3<"fma.rn.ftz", doF32FTZ>;
1619+
1620+
def : Pat<(fadd32x2_nvptx i64:$a, i64:$b),
1621+
(FADD32x2_ftz $a, $b)>, Requires<[doF32FTZ]>;
1622+
def : Pat<(fsub32x2_nvptx i64:$a, i64:$b),
1623+
(FSUB32x2_ftz $a, $b)>, Requires<[doF32FTZ]>;
1624+
def : Pat<(fmul32x2_nvptx i64:$a, i64:$b),
1625+
(FMUL32x2_ftz $a, $b)>, Requires<[doF32FTZ]>;
1626+
def : Pat<(fma32x2_nvptx i64:$a, i64:$b, i64:$c),
1627+
(FMA32x2_ftz $a, $b, $c)>, Requires<[doF32FTZ]>;
16051628

16061629
//
16071630
// BFIND

0 commit comments

Comments
 (0)