Skip to content

Commit a33744a

Browse files
committed
[NVPTX] Fixup v2i8 call lowering, use generic load/store nodes for call params
1 parent a258870 commit a33744a

38 files changed

+1676
-1681
lines changed

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ __device__ __bf16 external_func( __bf16 in);
3535
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
3636
__device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
38-
// CHECK: st.param.b16 [param0], %[[R]];
3938
// CHECK: .param .align 2 .b8 retval0[2];
39+
// CHECK: st.param.b16 [param0], %[[R]];
4040
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
4141
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
4242
return external_func(in);

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
145145
if (tryStoreVector(N))
146146
return;
147147
break;
148-
case NVPTXISD::LoadParam:
149-
case NVPTXISD::LoadParamV2:
150-
case NVPTXISD::LoadParamV4:
151-
if (tryLoadParam(N))
152-
return;
153-
break;
154-
case NVPTXISD::StoreParam:
155-
case NVPTXISD::StoreParamV2:
156-
case NVPTXISD::StoreParamV4:
157-
if (tryStoreParam(N))
158-
return;
159-
break;
160148
case ISD::INTRINSIC_W_CHAIN:
161149
if (tryIntrinsicChain(N))
162150
return;
@@ -1429,267 +1417,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14291417
return true;
14301418
}
14311419

1432-
bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
1433-
SDValue Chain = Node->getOperand(0);
1434-
SDValue Offset = Node->getOperand(2);
1435-
SDValue Glue = Node->getOperand(3);
1436-
SDLoc DL(Node);
1437-
MemSDNode *Mem = cast<MemSDNode>(Node);
1438-
1439-
unsigned VecSize;
1440-
switch (Node->getOpcode()) {
1441-
default:
1442-
return false;
1443-
case NVPTXISD::LoadParam:
1444-
VecSize = 1;
1445-
break;
1446-
case NVPTXISD::LoadParamV2:
1447-
VecSize = 2;
1448-
break;
1449-
case NVPTXISD::LoadParamV4:
1450-
VecSize = 4;
1451-
break;
1452-
}
1453-
1454-
EVT EltVT = Node->getValueType(0);
1455-
EVT MemVT = Mem->getMemoryVT();
1456-
1457-
std::optional<unsigned> Opcode;
1458-
1459-
switch (VecSize) {
1460-
default:
1461-
return false;
1462-
case 1:
1463-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1464-
NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
1465-
NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
1466-
break;
1467-
case 2:
1468-
Opcode =
1469-
pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
1470-
NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
1471-
NVPTX::LoadParamMemV2I64);
1472-
break;
1473-
case 4:
1474-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1475-
NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
1476-
NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
1477-
break;
1478-
}
1479-
if (!Opcode)
1480-
return false;
1481-
1482-
SDVTList VTs;
1483-
if (VecSize == 1) {
1484-
VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
1485-
} else if (VecSize == 2) {
1486-
VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
1487-
} else {
1488-
EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
1489-
VTs = CurDAG->getVTList(EVTs);
1490-
}
1491-
1492-
unsigned OffsetVal = Offset->getAsZExtVal();
1493-
1494-
SmallVector<SDValue, 2> Ops(
1495-
{CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1496-
1497-
ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
1498-
return true;
1499-
}
1500-
1501-
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
1502-
#define getOpcV2H(ty, opKind0, opKind1) \
1503-
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
1504-
1505-
#define getOpcV2H1(ty, opKind0, isImm1) \
1506-
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
1507-
1508-
#define getOpcodeForVectorStParamV2(ty, isimm) \
1509-
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
1510-
1511-
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
1512-
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
1513-
1514-
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
1515-
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
1516-
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)
1517-
1518-
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
1519-
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
1520-
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
1521-
1522-
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
1523-
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
1524-
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
1525-
1526-
#define getOpcodeForVectorStParamV4(ty, isimm) \
1527-
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
1528-
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
1529-
1530-
#define getOpcodeForVectorStParam(n, ty, isimm) \
1531-
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
1532-
: getOpcodeForVectorStParamV4(ty, isimm)
1533-
1534-
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
1535-
unsigned NumElts,
1536-
MVT::SimpleValueType MemTy,
1537-
SelectionDAG *CurDAG, SDLoc DL) {
1538-
// Determine which inputs are registers and immediates make new operators
1539-
// with constant values
1540-
SmallVector<bool, 4> IsImm(NumElts, false);
1541-
for (unsigned i = 0; i < NumElts; i++) {
1542-
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
1543-
if (IsImm[i]) {
1544-
SDValue Imm = Ops[i];
1545-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1546-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1547-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1548-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1549-
} else {
1550-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1551-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1552-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1553-
}
1554-
Ops[i] = Imm;
1555-
}
1556-
}
1557-
1558-
// Get opcode for MemTy, size, and register/immediate operand ordering
1559-
switch (MemTy) {
1560-
case MVT::i8:
1561-
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
1562-
case MVT::i16:
1563-
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
1564-
case MVT::i32:
1565-
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
1566-
case MVT::i64:
1567-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1568-
return getOpcodeForVectorStParamV2(I64, IsImm);
1569-
case MVT::f32:
1570-
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
1571-
case MVT::f64:
1572-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1573-
return getOpcodeForVectorStParamV2(F64, IsImm);
1574-
1575-
// These cases don't support immediates, just use the all register version
1576-
// and generate moves.
1577-
case MVT::i1:
1578-
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
1579-
: NVPTX::StoreParamV4I8_rrrr;
1580-
case MVT::f16:
1581-
case MVT::bf16:
1582-
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
1583-
: NVPTX::StoreParamV4I16_rrrr;
1584-
case MVT::v2f16:
1585-
case MVT::v2bf16:
1586-
case MVT::v2i16:
1587-
case MVT::v4i8:
1588-
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
1589-
: NVPTX::StoreParamV4I32_rrrr;
1590-
default:
1591-
llvm_unreachable("Cannot select st.param for unknown MemTy");
1592-
}
1593-
}
1594-
1595-
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
1596-
SDLoc DL(N);
1597-
SDValue Chain = N->getOperand(0);
1598-
SDValue Param = N->getOperand(1);
1599-
unsigned ParamVal = Param->getAsZExtVal();
1600-
SDValue Offset = N->getOperand(2);
1601-
unsigned OffsetVal = Offset->getAsZExtVal();
1602-
MemSDNode *Mem = cast<MemSDNode>(N);
1603-
SDValue Glue = N->getOperand(N->getNumOperands() - 1);
1604-
1605-
// How many elements do we have?
1606-
unsigned NumElts;
1607-
switch (N->getOpcode()) {
1608-
default:
1609-
llvm_unreachable("Unexpected opcode");
1610-
case NVPTXISD::StoreParam:
1611-
NumElts = 1;
1612-
break;
1613-
case NVPTXISD::StoreParamV2:
1614-
NumElts = 2;
1615-
break;
1616-
case NVPTXISD::StoreParamV4:
1617-
NumElts = 4;
1618-
break;
1619-
}
1620-
1621-
// Build vector of operands
1622-
SmallVector<SDValue, 8> Ops;
1623-
for (unsigned i = 0; i < NumElts; ++i)
1624-
Ops.push_back(N->getOperand(i + 3));
1625-
Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
1626-
CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1627-
1628-
// Determine target opcode
1629-
// If we have an i1, use an 8-bit store. The lowering code in
1630-
// NVPTXISelLowering will have already emitted an upcast.
1631-
std::optional<unsigned> Opcode;
1632-
switch (NumElts) {
1633-
default:
1634-
llvm_unreachable("Unexpected NumElts");
1635-
case 1: {
1636-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1637-
SDValue Imm = Ops[0];
1638-
if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
1639-
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
1640-
// Convert immediate to target constant
1641-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1642-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1643-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1644-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1645-
} else {
1646-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1647-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1648-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1649-
}
1650-
Ops[0] = Imm;
1651-
// Use immediate version of store param
1652-
Opcode =
1653-
pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
1654-
NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
1655-
} else
1656-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1657-
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
1658-
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
1659-
if (Opcode == NVPTX::StoreParamI8_r) {
1660-
// Fine tune the opcode depending on the size of the operand.
1661-
// This helps to avoid creating redundant COPY instructions in
1662-
// InstrEmitter::AddRegisterOperand().
1663-
switch (Ops[0].getSimpleValueType().SimpleTy) {
1664-
default:
1665-
break;
1666-
case MVT::i32:
1667-
Opcode = NVPTX::StoreParamI8TruncI32_r;
1668-
break;
1669-
case MVT::i64:
1670-
Opcode = NVPTX::StoreParamI8TruncI64_r;
1671-
break;
1672-
}
1673-
}
1674-
break;
1675-
}
1676-
case 2:
1677-
case 4: {
1678-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1679-
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
1680-
break;
1681-
}
1682-
}
1683-
1684-
SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
1685-
SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
1686-
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1687-
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
1688-
1689-
ReplaceNode(N, Ret);
1690-
return true;
1691-
}
1692-
16931420
/// SelectBFE - Look for instruction sequences that can be made more efficient
16941421
/// by using the 'bfe' (bit-field extract) PTX instruction
16951422
bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7878
bool tryLDG(MemSDNode *N);
7979
bool tryStore(SDNode *N);
8080
bool tryStoreVector(SDNode *N);
81-
bool tryLoadParam(SDNode *N);
82-
bool tryStoreParam(SDNode *N);
8381
bool tryFence(SDNode *N);
8482
void SelectAddrSpaceCast(SDNode *N);
8583
bool tryBFE(SDNode *N);

0 commit comments

Comments
 (0)