Skip to content

Commit 8080007

Browse files
committed
[NVPTX] Rework and cleanup FTZ ISel
1 parent 5dbd877 commit 8080007

35 files changed

+1296
-1541
lines changed

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
154154
llvm_unreachable("Invalid conversion modifier");
155155
}
156156

157+
void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O) {
158+
const MCOperand &MO = MI->getOperand(OpNum);
159+
const int Imm = MO.getImm();
160+
if (Imm)
161+
O << ".ftz";
162+
}
163+
157164
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
158165
StringRef Modifier) {
159166
const MCOperand &MO = MI->getOperand(OpNum);
160167
int64_t Imm = MO.getImm();
161168

162-
if (Modifier == "ftz") {
163-
// FTZ flag
164-
if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
165-
O << ".ftz";
166-
return;
167-
} else if (Modifier == "base") {
168-
switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
169+
if (Modifier == "FCmp") {
170+
switch (Imm) {
169171
default:
170172
return;
171173
case NVPTX::PTXCmpMode::EQ:
172-
O << ".eq";
174+
O << "eq";
173175
return;
174176
case NVPTX::PTXCmpMode::NE:
175-
O << ".ne";
177+
O << "ne";
176178
return;
177179
case NVPTX::PTXCmpMode::LT:
178-
O << ".lt";
180+
O << "lt";
179181
return;
180182
case NVPTX::PTXCmpMode::LE:
181-
O << ".le";
183+
O << "le";
182184
return;
183185
case NVPTX::PTXCmpMode::GT:
184-
O << ".gt";
186+
O << "gt";
185187
return;
186188
case NVPTX::PTXCmpMode::GE:
187-
O << ".ge";
188-
return;
189-
case NVPTX::PTXCmpMode::LO:
190-
O << ".lo";
191-
return;
192-
case NVPTX::PTXCmpMode::LS:
193-
O << ".ls";
194-
return;
195-
case NVPTX::PTXCmpMode::HI:
196-
O << ".hi";
197-
return;
198-
case NVPTX::PTXCmpMode::HS:
199-
O << ".hs";
189+
O << "ge";
200190
return;
201191
case NVPTX::PTXCmpMode::EQU:
202-
O << ".equ";
192+
O << "equ";
203193
return;
204194
case NVPTX::PTXCmpMode::NEU:
205-
O << ".neu";
195+
O << "neu";
206196
return;
207197
case NVPTX::PTXCmpMode::LTU:
208-
O << ".ltu";
198+
O << "ltu";
209199
return;
210200
case NVPTX::PTXCmpMode::LEU:
211-
O << ".leu";
201+
O << "leu";
212202
return;
213203
case NVPTX::PTXCmpMode::GTU:
214-
O << ".gtu";
204+
O << "gtu";
215205
return;
216206
case NVPTX::PTXCmpMode::GEU:
217-
O << ".geu";
207+
O << "geu";
218208
return;
219209
case NVPTX::PTXCmpMode::NUM:
220-
O << ".num";
210+
O << "num";
221211
return;
222212
case NVPTX::PTXCmpMode::NotANumber:
223-
O << ".nan";
213+
O << "nan";
214+
return;
215+
}
216+
}
217+
if (Modifier == "ICmp") {
218+
switch (Imm) {
219+
default:
220+
llvm_unreachable("Invalid ICmp mode");
221+
case NVPTX::PTXCmpMode::EQ:
222+
O << "eq";
223+
return;
224+
case NVPTX::PTXCmpMode::NE:
225+
O << "ne";
226+
return;
227+
case NVPTX::PTXCmpMode::LT:
228+
case NVPTX::PTXCmpMode::LTU:
229+
O << "lt";
230+
return;
231+
case NVPTX::PTXCmpMode::LE:
232+
case NVPTX::PTXCmpMode::LEU:
233+
O << "le";
234+
return;
235+
case NVPTX::PTXCmpMode::GT:
236+
case NVPTX::PTXCmpMode::GTU:
237+
O << "gt";
238+
return;
239+
case NVPTX::PTXCmpMode::GE:
240+
case NVPTX::PTXCmpMode::GEU:
241+
O << "ge";
242+
return;
243+
244+
}
245+
}
246+
if (Modifier == "IType") {
247+
switch (Imm) {
248+
default:
249+
llvm_unreachable("Invalid ICmp mode");
250+
case NVPTX::PTXCmpMode::EQ:
251+
case NVPTX::PTXCmpMode::NE:
252+
O << "b";
253+
return;
254+
case NVPTX::PTXCmpMode::LT:
255+
case NVPTX::PTXCmpMode::LE:
256+
case NVPTX::PTXCmpMode::GT:
257+
case NVPTX::PTXCmpMode::GE:
258+
O << "s";
259+
return;
260+
case NVPTX::PTXCmpMode::LTU:
261+
case NVPTX::PTXCmpMode::LEU:
262+
case NVPTX::PTXCmpMode::GTU:
263+
case NVPTX::PTXCmpMode::GEU:
264+
O << "u";
224265
return;
225266
}
226267
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
5454
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
5555
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
5656
StringRef Modifier = {});
57+
void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
5758
};
5859

5960
}

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "llvm/Support/AtomicOrdering.h"
2020
#include "llvm/Support/CodeGen.h"
2121
#include "llvm/Target/TargetMachine.h"
22-
22+
#include "llvm/CodeGen/ISDOpcodes.h"
2323
namespace llvm {
2424
class FunctionPass;
2525
class MachineFunctionPass;
@@ -218,28 +218,21 @@ enum CvtMode {
218218
/// PTXCmpMode - Comparison mode enumeration
219219
namespace PTXCmpMode {
220220
enum CmpMode {
221-
EQ = 0,
222-
NE,
223-
LT,
224-
LE,
225-
GT,
226-
GE,
227-
LO,
228-
LS,
229-
HI,
230-
HS,
231-
EQU,
232-
NEU,
233-
LTU,
234-
LEU,
235-
GTU,
236-
GEU,
237-
NUM,
221+
EQ = ISD::SETEQ,
222+
NE = ISD::SETNE,
223+
LT = ISD::SETLT,
224+
LE = ISD::SETLE,
225+
GT = ISD::SETGT,
226+
GE = ISD::SETGE,
227+
EQU = ISD::SETUEQ,
228+
NEU = ISD::SETUNE,
229+
LTU = ISD::SETULT,
230+
LEU = ISD::SETULE,
231+
GTU = ISD::SETUGT,
232+
GEU = ISD::SETUGE,
233+
NUM = ISD::SETO,
238234
// NAN is a MACRO
239-
NotANumber,
240-
241-
BASE_MASK = 0xFF,
242-
FTZ_FLAG = 0x100
235+
NotANumber = ISD::SETUO,
243236
};
244237
}
245238

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
363363

364364
// Map ISD:CONDCODE value to appropriate CmpMode expected by
365365
// NVPTXInstPrinter::printCmpMode()
366-
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
366+
SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
367367
using NVPTX::PTXCmpMode::CmpMode;
368-
unsigned PTXCmpMode = [](ISD::CondCode CC) {
368+
const unsigned PTXCmpMode = [](ISD::CondCode CC) {
369369
switch (CC) {
370370
default:
371371
llvm_unreachable("Unexpected condition code.");
372372
case ISD::SETOEQ:
373+
case ISD::SETEQ:
373374
return CmpMode::EQ;
374375
case ISD::SETOGT:
376+
case ISD::SETGT:
375377
return CmpMode::GT;
376378
case ISD::SETOGE:
379+
case ISD::SETGE:
377380
return CmpMode::GE;
378381
case ISD::SETOLT:
382+
case ISD::SETLT:
379383
return CmpMode::LT;
380384
case ISD::SETOLE:
385+
case ISD::SETLE:
381386
return CmpMode::LE;
382387
case ISD::SETONE:
388+
case ISD::SETNE:
383389
return CmpMode::NE;
384390
case ISD::SETO:
385391
return CmpMode::NUM;
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
397403
return CmpMode::LEU;
398404
case ISD::SETUNE:
399405
return CmpMode::NEU;
400-
case ISD::SETEQ:
401-
return CmpMode::EQ;
402-
case ISD::SETGT:
403-
return CmpMode::GT;
404-
case ISD::SETGE:
405-
return CmpMode::GE;
406-
case ISD::SETLT:
407-
return CmpMode::LT;
408-
case ISD::SETLE:
409-
return CmpMode::LE;
410-
case ISD::SETNE:
411-
return CmpMode::NE;
412406
}
413407
}(CondCode.get());
414-
415-
if (FTZ)
416-
PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
417-
418-
return PTXCmpMode;
408+
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
419409
}
420410

421411
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
422-
unsigned PTXCmpMode =
423-
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
412+
SDValue PTXCmpMode =
413+
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
424414
SDLoc DL(N);
425415
SDNode *SetP = CurDAG->getMachineNode(
426-
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
427-
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
416+
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
417+
N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
428418
ReplaceNode(N, SetP);
429419
return true;
430420
}
431421

432422
bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
433-
unsigned PTXCmpMode =
434-
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
423+
SDValue PTXCmpMode =
424+
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
435425
SDLoc DL(N);
436426
SDNode *SetP = CurDAG->getMachineNode(
437-
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
438-
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
427+
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
428+
N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
439429
ReplaceNode(N, SetP);
440430
return true;
441431
}
@@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
19531943
llvm_unreachable("Unexpected opcode");
19541944
};
19551945

1956-
int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
1946+
int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
19571947
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
19581948
ReplaceNode(N, FMA);
19591949
return true;

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
104104
}
105105

106106
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
107+
SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
107108
SDValue selectPossiblyImm(SDValue V);
108109

109110
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
110111

111-
static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
112-
113112
// Returns the Memory Order and Scope that the PTX memory instruction should
114113
// use, and inserts appropriate fence instruction before the memory
115114
// instruction, if needed to implement the instructions memory order. Required

0 commit comments

Comments
 (0)