Skip to content

Commit d37eb0f

Browse files
committed
[NVPTX] loads, stores of v2f32 are untyped
Ensures ld.b64 and st.b64 for v2f32. Also remove -O3 in f32x2-instructions.ll test.
1 parent 24f7f11 commit d37eb0f

File tree

2 files changed

+1037
-2100
lines changed

2 files changed

+1037
-2100
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ static int getLdStRegType(EVT VT) {
10741074
case MVT::bf16:
10751075
case MVT::v2f16:
10761076
case MVT::v2bf16:
1077+
case MVT::v2f32:
10771078
return NVPTX::PTXLdStInstCode::Untyped;
10781079
default:
10791080
return NVPTX::PTXLdStInstCode::Float;
@@ -1113,24 +1114,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11131114
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11141115
MVT SimpleVT = LoadedVT.getSimpleVT();
11151116
MVT ScalarVT = SimpleVT.getScalarType();
1116-
// Read at least 8 bits (predicates are stored as 8-bit values)
1117-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1118-
unsigned int FromType;
11191117

11201118
// Vector Setting
11211119
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11221120
if (SimpleVT.isVector()) {
1123-
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1124-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1125-
FromTypeWidth = 32;
1126-
else if (LoadedVT == MVT::v2f32)
1127-
// v2f32 is loaded using ld.b64
1128-
FromTypeWidth = 64;
1129-
else
1130-
llvm_unreachable("Unexpected vector type");
1121+
switch (LoadedVT.getSimpleVT().SimpleTy) {
1122+
case MVT::v2f16:
1123+
case MVT::v2bf16:
1124+
case MVT::v2i16:
1125+
case MVT::v4i8:
1126+
case MVT::v2f32:
1127+
ScalarVT = LoadedVT.getSimpleVT();
1128+
break;
1129+
default:
1130+
llvm_unreachable("Unsupported vector type for non-vector load");
1131+
}
11311132
}
11321133

1133-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1134+
// Read at least 8 bits (predicates are stored as 8-bit values)
1135+
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1136+
unsigned int FromType;
1137+
if (PlainLoad && PlainLoad->getExtensionType() == ISD::SEXTLOAD)
11341138
FromType = NVPTX::PTXLdStInstCode::Signed;
11351139
else
11361140
FromType = getLdStRegType(ScalarVT);
@@ -1438,18 +1442,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14381442
// Type Setting: toType + toTypeWidth
14391443
// - for integer type, always use 'u'
14401444
MVT ScalarVT = SimpleVT.getScalarType();
1441-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14421445
if (SimpleVT.isVector()) {
1443-
if (Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8)
1444-
// v2x16 is stored using st.b32
1445-
ToTypeWidth = 32;
1446-
else if (StoreVT == MVT::v2f32)
1447-
// v2f32 is stored using st.b64
1448-
ToTypeWidth = 64;
1449-
else
1450-
llvm_unreachable("Unexpected vector type");
1446+
switch (StoreVT.getSimpleVT().SimpleTy) {
1447+
case MVT::v2f16:
1448+
case MVT::v2bf16:
1449+
case MVT::v2i16:
1450+
case MVT::v4i8:
1451+
case MVT::v2f32:
1452+
ScalarVT = StoreVT.getSimpleVT();
1453+
break;
1454+
default:
1455+
llvm_unreachable("Unsupported vector type for non-vector store");
1456+
}
14511457
}
14521458

1459+
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14531460
unsigned int ToType = getLdStRegType(ScalarVT);
14541461

14551462
// Create the machine instruction DAG

0 commit comments

Comments
 (0)