Skip to content

Commit df909bb

Browse files
committed
[NVPTX] Fixup v2i8 call lowering, use generic load/store nodes for call params
1 parent aa27d4e commit df909bb

40 files changed

+1684
-1693
lines changed

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ __device__ __bf16 external_func( __bf16 in);
3535
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
3636
__device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
38-
// CHECK: st.param.b16 [param0], %[[R]];
3938
// CHECK: .param .align 2 .b8 retval0[2];
39+
// CHECK: st.param.b16 [param0], %[[R]];
4040
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
4141
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
4242
return external_func(in);

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
145145
if (tryStoreVector(N))
146146
return;
147147
break;
148-
case NVPTXISD::LoadParam:
149-
case NVPTXISD::LoadParamV2:
150-
case NVPTXISD::LoadParamV4:
151-
if (tryLoadParam(N))
152-
return;
153-
break;
154-
case NVPTXISD::StoreParam:
155-
case NVPTXISD::StoreParamV2:
156-
case NVPTXISD::StoreParamV4:
157-
if (tryStoreParam(N))
158-
return;
159-
break;
160148
case ISD::INTRINSIC_W_CHAIN:
161149
if (tryIntrinsicChain(N))
162150
return;
@@ -1419,267 +1407,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14191407
return true;
14201408
}
14211409

1422-
bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
1423-
SDValue Chain = Node->getOperand(0);
1424-
SDValue Offset = Node->getOperand(2);
1425-
SDValue Glue = Node->getOperand(3);
1426-
SDLoc DL(Node);
1427-
MemSDNode *Mem = cast<MemSDNode>(Node);
1428-
1429-
unsigned VecSize;
1430-
switch (Node->getOpcode()) {
1431-
default:
1432-
return false;
1433-
case NVPTXISD::LoadParam:
1434-
VecSize = 1;
1435-
break;
1436-
case NVPTXISD::LoadParamV2:
1437-
VecSize = 2;
1438-
break;
1439-
case NVPTXISD::LoadParamV4:
1440-
VecSize = 4;
1441-
break;
1442-
}
1443-
1444-
EVT EltVT = Node->getValueType(0);
1445-
EVT MemVT = Mem->getMemoryVT();
1446-
1447-
std::optional<unsigned> Opcode;
1448-
1449-
switch (VecSize) {
1450-
default:
1451-
return false;
1452-
case 1:
1453-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1454-
NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
1455-
NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
1456-
break;
1457-
case 2:
1458-
Opcode =
1459-
pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
1460-
NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
1461-
NVPTX::LoadParamMemV2I64);
1462-
break;
1463-
case 4:
1464-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1465-
NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
1466-
NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
1467-
break;
1468-
}
1469-
if (!Opcode)
1470-
return false;
1471-
1472-
SDVTList VTs;
1473-
if (VecSize == 1) {
1474-
VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
1475-
} else if (VecSize == 2) {
1476-
VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
1477-
} else {
1478-
EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
1479-
VTs = CurDAG->getVTList(EVTs);
1480-
}
1481-
1482-
unsigned OffsetVal = Offset->getAsZExtVal();
1483-
1484-
SmallVector<SDValue, 2> Ops(
1485-
{CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1486-
1487-
ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
1488-
return true;
1489-
}
1490-
1491-
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
1492-
#define getOpcV2H(ty, opKind0, opKind1) \
1493-
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
1494-
1495-
#define getOpcV2H1(ty, opKind0, isImm1) \
1496-
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
1497-
1498-
#define getOpcodeForVectorStParamV2(ty, isimm) \
1499-
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
1500-
1501-
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
1502-
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
1503-
1504-
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
1505-
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
1506-
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)
1507-
1508-
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
1509-
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
1510-
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
1511-
1512-
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
1513-
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
1514-
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
1515-
1516-
#define getOpcodeForVectorStParamV4(ty, isimm) \
1517-
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
1518-
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
1519-
1520-
#define getOpcodeForVectorStParam(n, ty, isimm) \
1521-
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
1522-
: getOpcodeForVectorStParamV4(ty, isimm)
1523-
1524-
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
1525-
unsigned NumElts,
1526-
MVT::SimpleValueType MemTy,
1527-
SelectionDAG *CurDAG, SDLoc DL) {
1528-
// Determine which inputs are registers and immediates make new operators
1529-
// with constant values
1530-
SmallVector<bool, 4> IsImm(NumElts, false);
1531-
for (unsigned i = 0; i < NumElts; i++) {
1532-
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
1533-
if (IsImm[i]) {
1534-
SDValue Imm = Ops[i];
1535-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1536-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1537-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1538-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1539-
} else {
1540-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1541-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1542-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1543-
}
1544-
Ops[i] = Imm;
1545-
}
1546-
}
1547-
1548-
// Get opcode for MemTy, size, and register/immediate operand ordering
1549-
switch (MemTy) {
1550-
case MVT::i8:
1551-
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
1552-
case MVT::i16:
1553-
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
1554-
case MVT::i32:
1555-
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
1556-
case MVT::i64:
1557-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1558-
return getOpcodeForVectorStParamV2(I64, IsImm);
1559-
case MVT::f32:
1560-
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
1561-
case MVT::f64:
1562-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1563-
return getOpcodeForVectorStParamV2(F64, IsImm);
1564-
1565-
// These cases don't support immediates, just use the all register version
1566-
// and generate moves.
1567-
case MVT::i1:
1568-
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
1569-
: NVPTX::StoreParamV4I8_rrrr;
1570-
case MVT::f16:
1571-
case MVT::bf16:
1572-
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
1573-
: NVPTX::StoreParamV4I16_rrrr;
1574-
case MVT::v2f16:
1575-
case MVT::v2bf16:
1576-
case MVT::v2i16:
1577-
case MVT::v4i8:
1578-
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
1579-
: NVPTX::StoreParamV4I32_rrrr;
1580-
default:
1581-
llvm_unreachable("Cannot select st.param for unknown MemTy");
1582-
}
1583-
}
1584-
1585-
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
1586-
SDLoc DL(N);
1587-
SDValue Chain = N->getOperand(0);
1588-
SDValue Param = N->getOperand(1);
1589-
unsigned ParamVal = Param->getAsZExtVal();
1590-
SDValue Offset = N->getOperand(2);
1591-
unsigned OffsetVal = Offset->getAsZExtVal();
1592-
MemSDNode *Mem = cast<MemSDNode>(N);
1593-
SDValue Glue = N->getOperand(N->getNumOperands() - 1);
1594-
1595-
// How many elements do we have?
1596-
unsigned NumElts;
1597-
switch (N->getOpcode()) {
1598-
default:
1599-
llvm_unreachable("Unexpected opcode");
1600-
case NVPTXISD::StoreParam:
1601-
NumElts = 1;
1602-
break;
1603-
case NVPTXISD::StoreParamV2:
1604-
NumElts = 2;
1605-
break;
1606-
case NVPTXISD::StoreParamV4:
1607-
NumElts = 4;
1608-
break;
1609-
}
1610-
1611-
// Build vector of operands
1612-
SmallVector<SDValue, 8> Ops;
1613-
for (unsigned i = 0; i < NumElts; ++i)
1614-
Ops.push_back(N->getOperand(i + 3));
1615-
Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
1616-
CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1617-
1618-
// Determine target opcode
1619-
// If we have an i1, use an 8-bit store. The lowering code in
1620-
// NVPTXISelLowering will have already emitted an upcast.
1621-
std::optional<unsigned> Opcode;
1622-
switch (NumElts) {
1623-
default:
1624-
llvm_unreachable("Unexpected NumElts");
1625-
case 1: {
1626-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1627-
SDValue Imm = Ops[0];
1628-
if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
1629-
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
1630-
// Convert immediate to target constant
1631-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1632-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1633-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1634-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1635-
} else {
1636-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1637-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1638-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1639-
}
1640-
Ops[0] = Imm;
1641-
// Use immediate version of store param
1642-
Opcode =
1643-
pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
1644-
NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
1645-
} else
1646-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1647-
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
1648-
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
1649-
if (Opcode == NVPTX::StoreParamI8_r) {
1650-
// Fine tune the opcode depending on the size of the operand.
1651-
// This helps to avoid creating redundant COPY instructions in
1652-
// InstrEmitter::AddRegisterOperand().
1653-
switch (Ops[0].getSimpleValueType().SimpleTy) {
1654-
default:
1655-
break;
1656-
case MVT::i32:
1657-
Opcode = NVPTX::StoreParamI8TruncI32_r;
1658-
break;
1659-
case MVT::i64:
1660-
Opcode = NVPTX::StoreParamI8TruncI64_r;
1661-
break;
1662-
}
1663-
}
1664-
break;
1665-
}
1666-
case 2:
1667-
case 4: {
1668-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1669-
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
1670-
break;
1671-
}
1672-
}
1673-
1674-
SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
1675-
SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
1676-
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1677-
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
1678-
1679-
ReplaceNode(N, Ret);
1680-
return true;
1681-
}
1682-
16831410
/// SelectBFE - Look for instruction sequences that can be made more efficient
16841411
/// by using the 'bfe' (bit-field extract) PTX instruction
16851412
bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7878
bool tryLDG(MemSDNode *N);
7979
bool tryStore(SDNode *N);
8080
bool tryStoreVector(SDNode *N);
81-
bool tryLoadParam(SDNode *N);
82-
bool tryStoreParam(SDNode *N);
8381
bool tryFence(SDNode *N);
8482
void SelectAddrSpaceCast(SDNode *N);
8583
bool tryBFE(SDNode *N);

0 commit comments

Comments
 (0)