@@ -1074,6 +1074,7 @@ static int getLdStRegType(EVT VT) {
1074
1074
case MVT::bf16 :
1075
1075
case MVT::v2f16:
1076
1076
case MVT::v2bf16:
1077
+ case MVT::v2f32:
1077
1078
return NVPTX::PTXLdStInstCode::Untyped;
1078
1079
default :
1079
1080
return NVPTX::PTXLdStInstCode::Float;
@@ -1113,24 +1114,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
1113
1114
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1114
1115
MVT SimpleVT = LoadedVT.getSimpleVT ();
1115
1116
MVT ScalarVT = SimpleVT.getScalarType ();
1116
- // Read at least 8 bits (predicates are stored as 8-bit values)
1117
- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1118
- unsigned int FromType;
1119
1117
1120
1118
// Vector Setting
1121
1119
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1122
1120
if (SimpleVT.isVector ()) {
1123
- if (Isv2x16VT (LoadedVT) || LoadedVT == MVT::v4i8)
1124
- // v2f16/v2bf16/v2i16 is loaded using ld.b32
1125
- FromTypeWidth = 32 ;
1126
- else if (LoadedVT == MVT::v2f32)
1127
- // v2f32 is loaded using ld.b64
1128
- FromTypeWidth = 64 ;
1129
- else
1130
- llvm_unreachable (" Unexpected vector type" );
1121
+ switch (LoadedVT.getSimpleVT ().SimpleTy ) {
1122
+ case MVT::v2f16:
1123
+ case MVT::v2bf16:
1124
+ case MVT::v2i16:
1125
+ case MVT::v4i8:
1126
+ case MVT::v2f32:
1127
+ ScalarVT = LoadedVT.getSimpleVT ();
1128
+ break ;
1129
+ default :
1130
+ llvm_unreachable (" Unsupported vector type for non-vector load" );
1131
+ }
1131
1132
}
1132
1133
1133
- if (PlainLoad && (PlainLoad->getExtensionType () == ISD::SEXTLOAD))
1134
+ // Read at least 8 bits (predicates are stored as 8-bit values)
1135
+ unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1136
+ unsigned int FromType;
1137
+ if (PlainLoad && PlainLoad->getExtensionType () == ISD::SEXTLOAD)
1134
1138
FromType = NVPTX::PTXLdStInstCode::Signed;
1135
1139
else
1136
1140
FromType = getLdStRegType (ScalarVT);
@@ -1438,18 +1442,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1438
1442
// Type Setting: toType + toTypeWidth
1439
1443
// - for integer type, always use 'u'
1440
1444
MVT ScalarVT = SimpleVT.getScalarType ();
1441
- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1442
1445
if (SimpleVT.isVector ()) {
1443
- if (Isv2x16VT (StoreVT) || StoreVT == MVT::v4i8)
1444
- // v2x16 is stored using st.b32
1445
- ToTypeWidth = 32 ;
1446
- else if (StoreVT == MVT::v2f32)
1447
- // v2f32 is stored using st.b64
1448
- ToTypeWidth = 64 ;
1449
- else
1450
- llvm_unreachable (" Unexpected vector type" );
1446
+ switch (StoreVT.getSimpleVT ().SimpleTy ) {
1447
+ case MVT::v2f16:
1448
+ case MVT::v2bf16:
1449
+ case MVT::v2i16:
1450
+ case MVT::v4i8:
1451
+ case MVT::v2f32:
1452
+ ScalarVT = StoreVT.getSimpleVT ();
1453
+ break ;
1454
+ default :
1455
+ llvm_unreachable (" Unsupported vector type for non-vector store" );
1456
+ }
1451
1457
}
1452
1458
1459
+ unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1453
1460
unsigned int ToType = getLdStRegType (ScalarVT);
1454
1461
1455
1462
// Create the machine instruction DAG
0 commit comments