Skip to content

Commit e161048

Browse files
committed
[NVPTX] Rework and cleanup FTZ ISel
1 parent 5dbd877 commit e161048

35 files changed

+1296
-1541
lines changed

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
154154
llvm_unreachable("Invalid conversion modifier");
155155
}
156156

157+
void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
158+
raw_ostream &O) {
159+
const MCOperand &MO = MI->getOperand(OpNum);
160+
const int Imm = MO.getImm();
161+
if (Imm)
162+
O << ".ftz";
163+
}
164+
157165
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
158166
StringRef Modifier) {
159167
const MCOperand &MO = MI->getOperand(OpNum);
160168
int64_t Imm = MO.getImm();
161169

162-
if (Modifier == "ftz") {
163-
// FTZ flag
164-
if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
165-
O << ".ftz";
166-
return;
167-
} else if (Modifier == "base") {
168-
switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
170+
if (Modifier == "FCmp") {
171+
switch (Imm) {
169172
default:
170173
return;
171174
case NVPTX::PTXCmpMode::EQ:
172-
O << ".eq";
175+
O << "eq";
173176
return;
174177
case NVPTX::PTXCmpMode::NE:
175-
O << ".ne";
178+
O << "ne";
176179
return;
177180
case NVPTX::PTXCmpMode::LT:
178-
O << ".lt";
181+
O << "lt";
179182
return;
180183
case NVPTX::PTXCmpMode::LE:
181-
O << ".le";
184+
O << "le";
182185
return;
183186
case NVPTX::PTXCmpMode::GT:
184-
O << ".gt";
187+
O << "gt";
185188
return;
186189
case NVPTX::PTXCmpMode::GE:
187-
O << ".ge";
188-
return;
189-
case NVPTX::PTXCmpMode::LO:
190-
O << ".lo";
191-
return;
192-
case NVPTX::PTXCmpMode::LS:
193-
O << ".ls";
194-
return;
195-
case NVPTX::PTXCmpMode::HI:
196-
O << ".hi";
197-
return;
198-
case NVPTX::PTXCmpMode::HS:
199-
O << ".hs";
190+
O << "ge";
200191
return;
201192
case NVPTX::PTXCmpMode::EQU:
202-
O << ".equ";
193+
O << "equ";
203194
return;
204195
case NVPTX::PTXCmpMode::NEU:
205-
O << ".neu";
196+
O << "neu";
206197
return;
207198
case NVPTX::PTXCmpMode::LTU:
208-
O << ".ltu";
199+
O << "ltu";
209200
return;
210201
case NVPTX::PTXCmpMode::LEU:
211-
O << ".leu";
202+
O << "leu";
212203
return;
213204
case NVPTX::PTXCmpMode::GTU:
214-
O << ".gtu";
205+
O << "gtu";
215206
return;
216207
case NVPTX::PTXCmpMode::GEU:
217-
O << ".geu";
208+
O << "geu";
218209
return;
219210
case NVPTX::PTXCmpMode::NUM:
220-
O << ".num";
211+
O << "num";
221212
return;
222213
case NVPTX::PTXCmpMode::NotANumber:
223-
O << ".nan";
214+
O << "nan";
215+
return;
216+
}
217+
}
218+
if (Modifier == "ICmp") {
219+
switch (Imm) {
220+
default:
221+
llvm_unreachable("Invalid ICmp mode");
222+
case NVPTX::PTXCmpMode::EQ:
223+
O << "eq";
224+
return;
225+
case NVPTX::PTXCmpMode::NE:
226+
O << "ne";
227+
return;
228+
case NVPTX::PTXCmpMode::LT:
229+
case NVPTX::PTXCmpMode::LTU:
230+
O << "lt";
231+
return;
232+
case NVPTX::PTXCmpMode::LE:
233+
case NVPTX::PTXCmpMode::LEU:
234+
O << "le";
235+
return;
236+
case NVPTX::PTXCmpMode::GT:
237+
case NVPTX::PTXCmpMode::GTU:
238+
O << "gt";
239+
return;
240+
case NVPTX::PTXCmpMode::GE:
241+
case NVPTX::PTXCmpMode::GEU:
242+
O << "ge";
243+
return;
244+
}
245+
}
246+
if (Modifier == "IType") {
247+
switch (Imm) {
248+
default:
249+
llvm_unreachable("Invalid IType");
250+
case NVPTX::PTXCmpMode::EQ:
251+
case NVPTX::PTXCmpMode::NE:
252+
O << "b";
253+
return;
254+
case NVPTX::PTXCmpMode::LT:
255+
case NVPTX::PTXCmpMode::LE:
256+
case NVPTX::PTXCmpMode::GT:
257+
case NVPTX::PTXCmpMode::GE:
258+
O << "s";
259+
return;
260+
case NVPTX::PTXCmpMode::LTU:
261+
case NVPTX::PTXCmpMode::LEU:
262+
case NVPTX::PTXCmpMode::GTU:
263+
case NVPTX::PTXCmpMode::GEU:
264+
O << "u";
224265
return;
225266
}
226267
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
5454
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
5555
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
5656
StringRef Modifier = {});
57+
void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
5758
};
5859

5960
}

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTX_H
1515
#define LLVM_LIB_TARGET_NVPTX_NVPTX_H
1616

17+
#include "llvm/CodeGen/ISDOpcodes.h"
1718
#include "llvm/IR/PassManager.h"
1819
#include "llvm/Pass.h"
1920
#include "llvm/Support/AtomicOrdering.h"
2021
#include "llvm/Support/CodeGen.h"
2122
#include "llvm/Target/TargetMachine.h"
22-
2323
namespace llvm {
2424
class FunctionPass;
2525
class MachineFunctionPass;
@@ -218,28 +218,21 @@ enum CvtMode {
218218
/// PTXCmpMode - Comparison mode enumeration
219219
namespace PTXCmpMode {
220220
enum CmpMode {
221-
EQ = 0,
222-
NE,
223-
LT,
224-
LE,
225-
GT,
226-
GE,
227-
LO,
228-
LS,
229-
HI,
230-
HS,
231-
EQU,
232-
NEU,
233-
LTU,
234-
LEU,
235-
GTU,
236-
GEU,
237-
NUM,
221+
EQ = ISD::SETEQ,
222+
NE = ISD::SETNE,
223+
LT = ISD::SETLT,
224+
LE = ISD::SETLE,
225+
GT = ISD::SETGT,
226+
GE = ISD::SETGE,
227+
EQU = ISD::SETUEQ,
228+
NEU = ISD::SETUNE,
229+
LTU = ISD::SETULT,
230+
LEU = ISD::SETULE,
231+
GTU = ISD::SETUGT,
232+
GEU = ISD::SETUGE,
233+
NUM = ISD::SETO,
238234
// NAN is a MACRO
239-
NotANumber,
240-
241-
BASE_MASK = 0xFF,
242-
FTZ_FLAG = 0x100
235+
NotANumber = ISD::SETUO,
243236
};
244237
}
245238

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
363363

364364
// Map ISD:CONDCODE value to appropriate CmpMode expected by
365365
// NVPTXInstPrinter::printCmpMode()
366-
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
366+
SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
367367
using NVPTX::PTXCmpMode::CmpMode;
368-
unsigned PTXCmpMode = [](ISD::CondCode CC) {
368+
const unsigned PTXCmpMode = [](ISD::CondCode CC) {
369369
switch (CC) {
370370
default:
371371
llvm_unreachable("Unexpected condition code.");
372372
case ISD::SETOEQ:
373+
case ISD::SETEQ:
373374
return CmpMode::EQ;
374375
case ISD::SETOGT:
376+
case ISD::SETGT:
375377
return CmpMode::GT;
376378
case ISD::SETOGE:
379+
case ISD::SETGE:
377380
return CmpMode::GE;
378381
case ISD::SETOLT:
382+
case ISD::SETLT:
379383
return CmpMode::LT;
380384
case ISD::SETOLE:
385+
case ISD::SETLE:
381386
return CmpMode::LE;
382387
case ISD::SETONE:
388+
case ISD::SETNE:
383389
return CmpMode::NE;
384390
case ISD::SETO:
385391
return CmpMode::NUM;
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
397403
return CmpMode::LEU;
398404
case ISD::SETUNE:
399405
return CmpMode::NEU;
400-
case ISD::SETEQ:
401-
return CmpMode::EQ;
402-
case ISD::SETGT:
403-
return CmpMode::GT;
404-
case ISD::SETGE:
405-
return CmpMode::GE;
406-
case ISD::SETLT:
407-
return CmpMode::LT;
408-
case ISD::SETLE:
409-
return CmpMode::LE;
410-
case ISD::SETNE:
411-
return CmpMode::NE;
412406
}
413407
}(CondCode.get());
414-
415-
if (FTZ)
416-
PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
417-
418-
return PTXCmpMode;
408+
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
419409
}
420410

421411
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
422-
unsigned PTXCmpMode =
423-
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
412+
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
424413
SDLoc DL(N);
425414
SDNode *SetP = CurDAG->getMachineNode(
426-
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
427-
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
415+
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1,
416+
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
417+
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
428418
ReplaceNode(N, SetP);
429419
return true;
430420
}
431421

432422
bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
433-
unsigned PTXCmpMode =
434-
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
423+
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
435424
SDLoc DL(N);
436425
SDNode *SetP = CurDAG->getMachineNode(
437-
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
438-
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
426+
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1,
427+
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
428+
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
439429
ReplaceNode(N, SetP);
440430
return true;
441431
}
@@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
19531943
llvm_unreachable("Unexpected opcode");
19541944
};
19551945

1956-
int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
1946+
int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
19571947
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
19581948
ReplaceNode(N, FMA);
19591949
return true;

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
104104
}
105105

106106
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
107+
SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
107108
SDValue selectPossiblyImm(SDValue V);
108109

109110
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
110111

111-
static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
112-
113112
// Returns the Memory Order and Scope that the PTX memory instruction should
114113
// use, and inserts appropriate fence instruction before the memory
115114
// instruction, if needed to implement the instructions memory order. Required

0 commit comments

Comments
 (0)