Skip to content

[NVPTX] Rework and cleanup FTZ ISel #146410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 74 additions & 33 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
llvm_unreachable("Invalid conversion modifier");
}

void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
raw_ostream &O) {
const MCOperand &MO = MI->getOperand(OpNum);
const int Imm = MO.getImm();
if (Imm)
O << ".ftz";
}

void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();

if (Modifier == "ftz") {
// FTZ flag
if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
O << ".ftz";
return;
} else if (Modifier == "base") {
switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
if (Modifier == "FCmp") {
switch (Imm) {
default:
return;
case NVPTX::PTXCmpMode::EQ:
O << ".eq";
O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
O << ".ne";
O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
O << ".lt";
O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
O << ".le";
O << "le";
return;
case NVPTX::PTXCmpMode::GT:
O << ".gt";
O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
O << ".ge";
return;
case NVPTX::PTXCmpMode::LO:
O << ".lo";
return;
case NVPTX::PTXCmpMode::LS:
O << ".ls";
return;
case NVPTX::PTXCmpMode::HI:
O << ".hi";
return;
case NVPTX::PTXCmpMode::HS:
O << ".hs";
O << "ge";
return;
case NVPTX::PTXCmpMode::EQU:
O << ".equ";
O << "equ";
return;
case NVPTX::PTXCmpMode::NEU:
O << ".neu";
O << "neu";
return;
case NVPTX::PTXCmpMode::LTU:
O << ".ltu";
O << "ltu";
return;
case NVPTX::PTXCmpMode::LEU:
O << ".leu";
O << "leu";
return;
case NVPTX::PTXCmpMode::GTU:
O << ".gtu";
O << "gtu";
return;
case NVPTX::PTXCmpMode::GEU:
O << ".geu";
O << "geu";
return;
case NVPTX::PTXCmpMode::NUM:
O << ".num";
O << "num";
return;
case NVPTX::PTXCmpMode::NotANumber:
O << ".nan";
O << "nan";
return;
}
}
if (Modifier == "ICmp") {
switch (Imm) {
default:
llvm_unreachable("Invalid ICmp mode");
case NVPTX::PTXCmpMode::EQ:
O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
case NVPTX::PTXCmpMode::LTU:
O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
case NVPTX::PTXCmpMode::LEU:
O << "le";
return;
case NVPTX::PTXCmpMode::GT:
case NVPTX::PTXCmpMode::GTU:
O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
case NVPTX::PTXCmpMode::GEU:
O << "ge";
return;
}
}
if (Modifier == "IType") {
switch (Imm) {
default:
llvm_unreachable("Invalid IType");
case NVPTX::PTXCmpMode::EQ:
case NVPTX::PTXCmpMode::NE:
O << "b";
return;
case NVPTX::PTXCmpMode::LT:
case NVPTX::PTXCmpMode::LE:
case NVPTX::PTXCmpMode::GT:
case NVPTX::PTXCmpMode::GE:
O << "s";
return;
case NVPTX::PTXCmpMode::LTU:
case NVPTX::PTXCmpMode::LEU:
case NVPTX::PTXCmpMode::GTU:
case NVPTX::PTXCmpMode::GEU:
O << "u";
return;
}
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
};

}
Expand Down
9 changes: 1 addition & 8 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTX_H
#define LLVM_LIB_TARGET_NVPTX_NVPTX_H

#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Target/TargetMachine.h"

namespace llvm {
class FunctionPass;
class MachineFunctionPass;
Expand Down Expand Up @@ -224,10 +224,6 @@ enum CmpMode {
LE,
GT,
GE,
LO,
LS,
HI,
HS,
EQU,
NEU,
LTU,
Expand All @@ -237,9 +233,6 @@ enum CmpMode {
NUM,
// NAN is a MACRO
NotANumber,

BASE_MASK = 0xFF,
FTZ_FLAG = 0x100
};
}

Expand Down
46 changes: 18 additions & 28 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {

// Map ISD:CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
using NVPTX::PTXCmpMode::CmpMode;
unsigned PTXCmpMode = [](ISD::CondCode CC) {
const unsigned PTXCmpMode = [](ISD::CondCode CC) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and we already have the ISD->CmpMode mapping helper function.

And it appears that the ISD->CmpMode mapping is not 1:1, but rather N:1, so using ISD values for CmpMode enums is probably not buying us much -- we still need a helper function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above

switch (CC) {
default:
llvm_unreachable("Unexpected condition code.");
case ISD::SETOEQ:
case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETOGT:
case ISD::SETGT:
return CmpMode::GT;
case ISD::SETOGE:
case ISD::SETGE:
return CmpMode::GE;
case ISD::SETOLT:
case ISD::SETLT:
return CmpMode::LT;
case ISD::SETOLE:
case ISD::SETLE:
return CmpMode::LE;
case ISD::SETONE:
case ISD::SETNE:
return CmpMode::NE;
case ISD::SETO:
return CmpMode::NUM;
Expand All @@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
return CmpMode::LEU;
case ISD::SETUNE:
return CmpMode::NEU;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how you've defined CmpMode in terms of SDNode opcodes, couldn't you just return the opcode for these cases?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the below discussion I've changed this back.

case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETGT:
return CmpMode::GT;
case ISD::SETGE:
return CmpMode::GE;
case ISD::SETLT:
return CmpMode::LT;
case ISD::SETLE:
return CmpMode::LE;
case ISD::SETNE:
return CmpMode::NE;
}
}(CondCode.get());

if (FTZ)
PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;

return PTXCmpMode;
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(&CondCode), MVT::i32);

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we know that this is just going to be an operand within another instruction, there is no point in adding a debug-loc.

}

// Builds a SETP_f16x2rr machine node for N and replaces N with it; always
// returns true. The node is created with two MVT::i1 result types, and the
// comparison mode is derived from the CondCodeSDNode held in operand 2 of N.
// NOTE(review): this span appears to be a rendered diff, so two versions of
// the PTXCmpMode declaration and of the getMachineNode argument list are shown
// back to back -- presumably the first of each pair is the removed code;
// confirm against the actual file.
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
unsigned PTXCmpMode =
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1,
// In this version the FTZ setting travels as its own i1 target constant
// taken from useF32FTZ(), separate from the comparison-mode operand.
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}

// Builds a SETP_bf16x2rr machine node for N and replaces N with it; always
// returns true. The node is created with two MVT::i1 result types, and the
// comparison mode is derived from the CondCodeSDNode held in operand 2 of N.
// NOTE(review): this span appears to be a rendered diff, so two versions of
// the PTXCmpMode declaration and of the getMachineNode argument list are shown
// back to back -- presumably the first of each pair is the removed code;
// confirm against the actual file.
bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
unsigned PTXCmpMode =
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1,
// In this version the FTZ setting travels as its own i1 target constant
// taken from useF32FTZ(), separate from the comparison-mode operand.
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
Expand Down Expand Up @@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
llvm_unreachable("Unexpected opcode");
};

int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
ReplaceNode(N, FMA);
return true;
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
}

bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
SDValue selectPossiblyImm(SDValue V);

bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;

static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);

// Returns the Memory Order and Scope that the PTX memory instruction should
// use, and inserts appropriate fence instruction before the memory
// instruction, if needed to implement the instructions memory order. Required
Expand Down
Loading