Skip to content

Commit a6eed92

Browse files
committed
[NVPTX] loads, stores of v2f32 are untyped
Ensures ld.b64 and st.b64 for v2f32. Also remove -O3 in f32x2-instructions.ll test.
1 parent db55ccb commit a6eed92

File tree

2 files changed

+1037
-2100
lines changed

2 files changed

+1037
-2100
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ static int getLdStRegType(EVT VT) {
10601060
case MVT::bf16:
10611061
case MVT::v2f16:
10621062
case MVT::v2bf16:
1063+
case MVT::v2f32:
10631064
return NVPTX::PTXLdStInstCode::Untyped;
10641065
default:
10651066
return NVPTX::PTXLdStInstCode::Float;
@@ -1099,24 +1100,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10991100
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11001101
MVT SimpleVT = LoadedVT.getSimpleVT();
11011102
MVT ScalarVT = SimpleVT.getScalarType();
1102-
// Read at least 8 bits (predicates are stored as 8-bit values)
1103-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1104-
unsigned int FromType;
11051103

11061104
// Vector Setting
11071105
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11081106
if (SimpleVT.isVector()) {
1109-
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1110-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1111-
FromTypeWidth = 32;
1112-
else if (LoadedVT == MVT::v2f32)
1113-
// v2f32 is loaded using ld.b64
1114-
FromTypeWidth = 64;
1115-
else
1116-
llvm_unreachable("Unexpected vector type");
1107+
switch (LoadedVT.getSimpleVT().SimpleTy) {
1108+
case MVT::v2f16:
1109+
case MVT::v2bf16:
1110+
case MVT::v2i16:
1111+
case MVT::v4i8:
1112+
case MVT::v2f32:
1113+
ScalarVT = LoadedVT.getSimpleVT();
1114+
break;
1115+
default:
1116+
llvm_unreachable("Unsupported vector type for non-vector load");
1117+
}
11171118
}
11181119

1119-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1120+
// Read at least 8 bits (predicates are stored as 8-bit values)
1121+
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1122+
unsigned int FromType;
1123+
if (PlainLoad && PlainLoad->getExtensionType() == ISD::SEXTLOAD)
11201124
FromType = NVPTX::PTXLdStInstCode::Signed;
11211125
else
11221126
FromType = getLdStRegType(ScalarVT);
@@ -1424,18 +1428,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14241428
// Type Setting: toType + toTypeWidth
14251429
// - for integer type, always use 'u'
14261430
MVT ScalarVT = SimpleVT.getScalarType();
1427-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14281431
if (SimpleVT.isVector()) {
1429-
if (Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8)
1430-
// v2x16 is stored using st.b32
1431-
ToTypeWidth = 32;
1432-
else if (StoreVT == MVT::v2f32)
1433-
// v2f32 is stored using st.b64
1434-
ToTypeWidth = 64;
1435-
else
1436-
llvm_unreachable("Unexpected vector type");
1432+
switch (StoreVT.getSimpleVT().SimpleTy) {
1433+
case MVT::v2f16:
1434+
case MVT::v2bf16:
1435+
case MVT::v2i16:
1436+
case MVT::v4i8:
1437+
case MVT::v2f32:
1438+
ScalarVT = StoreVT.getSimpleVT();
1439+
break;
1440+
default:
1441+
llvm_unreachable("Unsupported vector type for non-vector store");
1442+
}
14371443
}
14381444

1445+
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14391446
unsigned int ToType = getLdStRegType(ScalarVT);
14401447

14411448
// Create the machine instruction DAG

0 commit comments

Comments
 (0)