@@ -1060,6 +1060,7 @@ static int getLdStRegType(EVT VT) {
1060
1060
case MVT::bf16 :
1061
1061
case MVT::v2f16:
1062
1062
case MVT::v2bf16:
1063
+ case MVT::v2f32:
1063
1064
return NVPTX::PTXLdStInstCode::Untyped;
1064
1065
default :
1065
1066
return NVPTX::PTXLdStInstCode::Float;
@@ -1099,24 +1100,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
1099
1100
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1100
1101
MVT SimpleVT = LoadedVT.getSimpleVT ();
1101
1102
MVT ScalarVT = SimpleVT.getScalarType ();
1102
- // Read at least 8 bits (predicates are stored as 8-bit values)
1103
- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1104
- unsigned int FromType;
1105
1103
1106
1104
// Vector Setting
1107
1105
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1108
1106
if (SimpleVT.isVector ()) {
1109
- if (Isv2x16VT (LoadedVT) || LoadedVT == MVT::v4i8)
1110
- // v2f16/v2bf16/v2i16 is loaded using ld.b32
1111
- FromTypeWidth = 32 ;
1112
- else if (LoadedVT == MVT::v2f32)
1113
- // v2f32 is loaded using ld.b64
1114
- FromTypeWidth = 64 ;
1115
- else
1116
- llvm_unreachable (" Unexpected vector type" );
1107
+ switch (LoadedVT.getSimpleVT ().SimpleTy ) {
1108
+ case MVT::v2f16:
1109
+ case MVT::v2bf16:
1110
+ case MVT::v2i16:
1111
+ case MVT::v4i8:
1112
+ case MVT::v2f32:
1113
+ ScalarVT = LoadedVT.getSimpleVT ();
1114
+ break ;
1115
+ default :
1116
+ llvm_unreachable (" Unsupported vector type for non-vector load" );
1117
+ }
1117
1118
}
1118
1119
1119
- if (PlainLoad && (PlainLoad->getExtensionType () == ISD::SEXTLOAD))
1120
+ // Read at least 8 bits (predicates are stored as 8-bit values)
1121
+ unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1122
+ unsigned int FromType;
1123
+ if (PlainLoad && PlainLoad->getExtensionType () == ISD::SEXTLOAD)
1120
1124
FromType = NVPTX::PTXLdStInstCode::Signed;
1121
1125
else
1122
1126
FromType = getLdStRegType (ScalarVT);
@@ -1424,18 +1428,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1424
1428
// Type Setting: toType + toTypeWidth
1425
1429
// - for integer type, always use 'u'
1426
1430
MVT ScalarVT = SimpleVT.getScalarType ();
1427
- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1428
1431
if (SimpleVT.isVector ()) {
1429
- if (Isv2x16VT (StoreVT) || StoreVT == MVT::v4i8)
1430
- // v2x16 is stored using st.b32
1431
- ToTypeWidth = 32 ;
1432
- else if (StoreVT == MVT::v2f32)
1433
- // v2f32 is stored using st.b64
1434
- ToTypeWidth = 64 ;
1435
- else
1436
- llvm_unreachable (" Unexpected vector type" );
1432
+ switch (StoreVT.getSimpleVT ().SimpleTy ) {
1433
+ case MVT::v2f16:
1434
+ case MVT::v2bf16:
1435
+ case MVT::v2i16:
1436
+ case MVT::v4i8:
1437
+ case MVT::v2f32:
1438
+ ScalarVT = StoreVT.getSimpleVT ();
1439
+ break ;
1440
+ default :
1441
+ llvm_unreachable (" Unsupported vector type for non-vector store" );
1442
+ }
1437
1443
}
1438
1444
1445
+ unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1439
1446
unsigned int ToType = getLdStRegType (ScalarVT);
1440
1447
1441
1448
// Create the machine instruction DAG
0 commit comments