@@ -875,16 +875,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
875
875
// ===----------------------------------------------------------------------===//
876
876
877
877
OpFoldResult arith::ExtUIOp::fold (ArrayRef<Attribute> operands) {
878
- if (auto lhs = operands[0 ].dyn_cast_or_null <IntegerAttr>())
879
- return IntegerAttr::get (
880
- getType (), lhs.getValue ().zext (getType ().getIntOrFloatBitWidth ()));
881
-
882
878
if (auto lhs = getIn ().getDefiningOp <ExtUIOp>()) {
883
879
getInMutable ().assign (lhs.getIn ());
884
880
return getResult ();
885
881
}
886
-
887
- return {};
882
+ Type resType = getType ();
883
+ unsigned bitWidth;
884
+ if (auto shapedType = resType.dyn_cast <ShapedType>())
885
+ bitWidth = shapedType.getElementTypeBitWidth ();
886
+ else
887
+ bitWidth = resType.getIntOrFloatBitWidth ();
888
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
889
+ operands, getType (), [bitWidth](const APInt &a, bool &castStatus) {
890
+ return a.zext (bitWidth);
891
+ });
888
892
}
889
893
890
894
bool arith::ExtUIOp::areCastCompatible (TypeRange inputs, TypeRange outputs) {
@@ -900,16 +904,20 @@ LogicalResult arith::ExtUIOp::verify() {
900
904
// ===----------------------------------------------------------------------===//
901
905
902
906
OpFoldResult arith::ExtSIOp::fold (ArrayRef<Attribute> operands) {
903
- if (auto lhs = operands[0 ].dyn_cast_or_null <IntegerAttr>())
904
- return IntegerAttr::get (
905
- getType (), lhs.getValue ().sext (getType ().getIntOrFloatBitWidth ()));
906
-
907
907
if (auto lhs = getIn ().getDefiningOp <ExtSIOp>()) {
908
908
getInMutable ().assign (lhs.getIn ());
909
909
return getResult ();
910
910
}
911
-
912
- return {};
911
+ Type resType = getType ();
912
+ unsigned bitWidth;
913
+ if (auto shapedType = resType.dyn_cast <ShapedType>())
914
+ bitWidth = shapedType.getElementTypeBitWidth ();
915
+ else
916
+ bitWidth = resType.getIntOrFloatBitWidth ();
917
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
918
+ operands, getType (), [bitWidth](const APInt &a, bool &castStatus) {
919
+ return a.sext (bitWidth);
920
+ });
913
921
}
914
922
915
923
bool arith::ExtSIOp::areCastCompatible (TypeRange inputs, TypeRange outputs) {
@@ -954,15 +962,17 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
954
962
return getResult ();
955
963
}
956
964
957
- if (!operands[0 ])
958
- return {};
965
+ Type resType = getType ();
966
+ unsigned bitWidth;
967
+ if (auto shapedType = resType.dyn_cast <ShapedType>())
968
+ bitWidth = shapedType.getElementTypeBitWidth ();
969
+ else
970
+ bitWidth = resType.getIntOrFloatBitWidth ();
959
971
960
- if (auto lhs = operands[0 ].dyn_cast <IntegerAttr>()) {
961
- return IntegerAttr::get (
962
- getType (), lhs.getValue ().trunc (getType ().getIntOrFloatBitWidth ()));
963
- }
964
-
965
- return {};
972
+ return constFoldCastOp<IntegerAttr, IntegerAttr>(
973
+ operands, getType (), [bitWidth](const APInt &a, bool &castStatus) {
974
+ return a.trunc (bitWidth);
975
+ });
966
976
}
967
977
968
978
bool arith::TruncIOp::areCastCompatible (TypeRange inputs, TypeRange outputs) {
@@ -1048,15 +1058,21 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1048
1058
}
1049
1059
1050
1060
OpFoldResult arith::UIToFPOp::fold (ArrayRef<Attribute> operands) {
1051
- if (auto lhs = operands[0 ].dyn_cast_or_null <IntegerAttr>()) {
1052
- const APInt &api = lhs.getValue ();
1053
- FloatType floatTy = getType ().cast <FloatType>();
1054
- APFloat apf (floatTy.getFloatSemantics (),
1055
- APInt::getZero (floatTy.getWidth ()));
1056
- apf.convertFromAPInt (api, /* IsSigned=*/ false , APFloat::rmNearestTiesToEven);
1057
- return FloatAttr::get (floatTy, apf);
1058
- }
1059
- return {};
1061
+ Type resType = getType ();
1062
+ Type resEleType;
1063
+ if (auto shapedType = resType.dyn_cast <ShapedType>())
1064
+ resEleType = shapedType.getElementType ();
1065
+ else
1066
+ resEleType = resType;
1067
+ return constFoldCastOp<IntegerAttr, FloatAttr>(
1068
+ operands, getType (), [&resEleType](const APInt &a, bool &castStatus) {
1069
+ FloatType floatTy = resEleType.cast <FloatType>();
1070
+ APFloat apf (floatTy.getFloatSemantics (),
1071
+ APInt::getZero (floatTy.getWidth ()));
1072
+ apf.convertFromAPInt (a, /* IsSigned=*/ false ,
1073
+ APFloat::rmNearestTiesToEven);
1074
+ return apf;
1075
+ });
1060
1076
}
1061
1077
1062
1078
// ===----------------------------------------------------------------------===//
@@ -1068,15 +1084,21 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1068
1084
}
1069
1085
1070
1086
OpFoldResult arith::SIToFPOp::fold (ArrayRef<Attribute> operands) {
1071
- if (auto lhs = operands[0 ].dyn_cast_or_null <IntegerAttr>()) {
1072
- const APInt &api = lhs.getValue ();
1073
- FloatType floatTy = getType ().cast <FloatType>();
1074
- APFloat apf (floatTy.getFloatSemantics (),
1075
- APInt::getZero (floatTy.getWidth ()));
1076
- apf.convertFromAPInt (api, /* IsSigned=*/ true , APFloat::rmNearestTiesToEven);
1077
- return FloatAttr::get (floatTy, apf);
1078
- }
1079
- return {};
1087
+ Type resType = getType ();
1088
+ Type resEleType;
1089
+ if (auto shapedType = resType.dyn_cast <ShapedType>())
1090
+ resEleType = shapedType.getElementType ();
1091
+ else
1092
+ resEleType = resType;
1093
+ return constFoldCastOp<IntegerAttr, FloatAttr>(
1094
+ operands, getType (), [&resEleType](const APInt &a, bool &castStatus) {
1095
+ FloatType floatTy = resEleType.cast <FloatType>();
1096
+ APFloat apf (floatTy.getFloatSemantics (),
1097
+ APInt::getZero (floatTy.getWidth ()));
1098
+ apf.convertFromAPInt (a, /* IsSigned=*/ true ,
1099
+ APFloat::rmNearestTiesToEven);
1100
+ return apf;
1101
+ });
1080
1102
}
1081
1103
// ===----------------------------------------------------------------------===//
1082
1104
// FPToUIOp
@@ -1087,21 +1109,21 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1087
1109
}
1088
1110
1089
1111
OpFoldResult arith::FPToUIOp::fold (ArrayRef<Attribute> operands) {
1090
- if ( auto lhs = operands[ 0 ]. dyn_cast_or_null <FloatAttr>()) {
1091
- const APFloat &apf = lhs. getValue () ;
1092
- IntegerType intTy = getType (). cast <IntegerType >();
1093
- bool ignored ;
1094
- APSInt api (intTy. getWidth (), /* isUnsigned= */ true );
1095
- if (APFloat::opInvalidOp ==
1096
- apf. convertToInteger (api, APFloat::rmTowardZero, &ignored)) {
1097
- // Undefined behavior invoked - the destination type can't represent
1098
- // the input constant.
1099
- return {} ;
1100
- }
1101
- return IntegerAttr::get ( getType (), api);
1102
- }
1103
-
1104
- return {} ;
1112
+ Type resType = getType ();
1113
+ Type resEleType ;
1114
+ if ( auto shapedType = resType. dyn_cast <ShapedType >())
1115
+ resEleType = shapedType. getElementType () ;
1116
+ else
1117
+ resEleType = resType;
1118
+ return constFoldCastOp<FloatAttr, IntegerAttr>(
1119
+ operands, getType (), [&resEleType]( const APFloat &a, bool &castStatus) {
1120
+ IntegerType intTy = resEleType. cast <IntegerType>();
1121
+ bool ignored ;
1122
+ APSInt api (intTy. getWidth (), /* isUnsigned= */ true );
1123
+ castStatus = APFloat::opInvalidOp !=
1124
+ a. convertToInteger (api, APFloat::rmTowardZero, &ignored);
1125
+ return api;
1126
+ }) ;
1105
1127
}
1106
1128
1107
1129
// ===----------------------------------------------------------------------===//
@@ -1113,21 +1135,21 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1113
1135
}
1114
1136
1115
1137
OpFoldResult arith::FPToSIOp::fold (ArrayRef<Attribute> operands) {
1116
- if ( auto lhs = operands[ 0 ]. dyn_cast_or_null <FloatAttr>()) {
1117
- const APFloat &apf = lhs. getValue () ;
1118
- IntegerType intTy = getType (). cast <IntegerType >();
1119
- bool ignored ;
1120
- APSInt api (intTy. getWidth (), /* isUnsigned= */ false );
1121
- if (APFloat::opInvalidOp ==
1122
- apf. convertToInteger (api, APFloat::rmTowardZero, &ignored)) {
1123
- // Undefined behavior invoked - the destination type can't represent
1124
- // the input constant.
1125
- return {} ;
1126
- }
1127
- return IntegerAttr::get ( getType (), api);
1128
- }
1129
-
1130
- return {} ;
1138
+ Type resType = getType ();
1139
+ Type resEleType ;
1140
+ if ( auto shapedType = resType. dyn_cast <ShapedType >())
1141
+ resEleType = shapedType. getElementType () ;
1142
+ else
1143
+ resEleType = resType;
1144
+ return constFoldCastOp<FloatAttr, IntegerAttr>(
1145
+ operands, getType (), [&resEleType]( const APFloat &a, bool &castStatus) {
1146
+ IntegerType intTy = resEleType. cast <IntegerType>();
1147
+ bool ignored ;
1148
+ APSInt api (intTy. getWidth (), /* isUnsigned= */ false );
1149
+ castStatus = APFloat::opInvalidOp !=
1150
+ a. convertToInteger (api, APFloat::rmTowardZero, &ignored);
1151
+ return api;
1152
+ }) ;
1131
1153
}
1132
1154
1133
1155
// ===----------------------------------------------------------------------===//
0 commit comments