Skip to content

Commit 605fc89

Browse files
jacquesguanjacquesguan
authored andcommitted
[mlir][Arithmetic] Add common constant folder function for type cast ops.
This revision replaces current type cast constant folder with a new common type cast constant folder function template. It will cover all former folder and support fold the constant splat and vector. Differential Revision: https://reviews.llvm.org/D123489
1 parent 47a9528 commit 605fc89

File tree

3 files changed

+283
-69
lines changed

3 files changed

+283
-69
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,56 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
108108
return {};
109109
}
110110

111+
template <
112+
class AttrElementT, class TargetAttrElementT,
113+
class ElementValueT = typename AttrElementT::ValueType,
114+
class TargetElementValueT = typename TargetAttrElementT::ValueType,
115+
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
116+
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
117+
const CalculationT &calculate) {
118+
assert(operands.size() == 1 && "Cast op takes one operand");
119+
if (!operands[0])
120+
return {};
121+
122+
if (operands[0].isa<AttrElementT>()) {
123+
auto op = operands[0].cast<AttrElementT>();
124+
bool castStatus = true;
125+
auto res = calculate(op.getValue(), castStatus);
126+
if (!castStatus)
127+
return {};
128+
return TargetAttrElementT::get(resType, res);
129+
}
130+
if (operands[0].isa<SplatElementsAttr>()) {
131+
// The operand is a splat so we can avoid expanding the values out and
132+
// just fold based on the splat value.
133+
auto op = operands[0].cast<SplatElementsAttr>();
134+
bool castStatus = true;
135+
auto elementResult =
136+
calculate(op.getSplatValue<ElementValueT>(), castStatus);
137+
if (!castStatus)
138+
return {};
139+
return DenseElementsAttr::get(resType, elementResult);
140+
}
141+
if (operands[0].isa<ElementsAttr>()) {
142+
// Operand is ElementsAttr-derived; perform an element-wise fold by
143+
// expanding the value.
144+
auto op = operands[0].cast<ElementsAttr>();
145+
bool castStatus = true;
146+
auto opIt = op.value_begin<ElementValueT>();
147+
SmallVector<TargetElementValueT> elementResults;
148+
elementResults.reserve(op.getNumElements());
149+
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
150+
auto elt = calculate(*opIt, castStatus);
151+
if (!castStatus)
152+
return {};
153+
elementResults.push_back(elt);
154+
}
155+
156+
return DenseElementsAttr::get(resType, elementResults);
157+
}
158+
return {};
159+
}
160+
111161
} // namespace mlir
112162

113163
#endif // MLIR_DIALECT_COMMONFOLDERS_H

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Lines changed: 90 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -875,16 +875,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
875875
//===----------------------------------------------------------------------===//
876876

877877
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
878-
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
879-
return IntegerAttr::get(
880-
getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
881-
882878
if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
883879
getInMutable().assign(lhs.getIn());
884880
return getResult();
885881
}
886-
887-
return {};
882+
Type resType = getType();
883+
unsigned bitWidth;
884+
if (auto shapedType = resType.dyn_cast<ShapedType>())
885+
bitWidth = shapedType.getElementTypeBitWidth();
886+
else
887+
bitWidth = resType.getIntOrFloatBitWidth();
888+
return constFoldCastOp<IntegerAttr, IntegerAttr>(
889+
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
890+
return a.zext(bitWidth);
891+
});
888892
}
889893

890894
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -900,16 +904,20 @@ LogicalResult arith::ExtUIOp::verify() {
900904
//===----------------------------------------------------------------------===//
901905

902906
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
903-
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
904-
return IntegerAttr::get(
905-
getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
906-
907907
if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
908908
getInMutable().assign(lhs.getIn());
909909
return getResult();
910910
}
911-
912-
return {};
911+
Type resType = getType();
912+
unsigned bitWidth;
913+
if (auto shapedType = resType.dyn_cast<ShapedType>())
914+
bitWidth = shapedType.getElementTypeBitWidth();
915+
else
916+
bitWidth = resType.getIntOrFloatBitWidth();
917+
return constFoldCastOp<IntegerAttr, IntegerAttr>(
918+
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
919+
return a.sext(bitWidth);
920+
});
913921
}
914922

915923
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -954,15 +962,17 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
954962
return getResult();
955963
}
956964

957-
if (!operands[0])
958-
return {};
965+
Type resType = getType();
966+
unsigned bitWidth;
967+
if (auto shapedType = resType.dyn_cast<ShapedType>())
968+
bitWidth = shapedType.getElementTypeBitWidth();
969+
else
970+
bitWidth = resType.getIntOrFloatBitWidth();
959971

960-
if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
961-
return IntegerAttr::get(
962-
getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
963-
}
964-
965-
return {};
972+
return constFoldCastOp<IntegerAttr, IntegerAttr>(
973+
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
974+
return a.trunc(bitWidth);
975+
});
966976
}
967977

968978
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1048,15 +1058,21 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
10481058
}
10491059

10501060
OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1051-
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1052-
const APInt &api = lhs.getValue();
1053-
FloatType floatTy = getType().cast<FloatType>();
1054-
APFloat apf(floatTy.getFloatSemantics(),
1055-
APInt::getZero(floatTy.getWidth()));
1056-
apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
1057-
return FloatAttr::get(floatTy, apf);
1058-
}
1059-
return {};
1061+
Type resType = getType();
1062+
Type resEleType;
1063+
if (auto shapedType = resType.dyn_cast<ShapedType>())
1064+
resEleType = shapedType.getElementType();
1065+
else
1066+
resEleType = resType;
1067+
return constFoldCastOp<IntegerAttr, FloatAttr>(
1068+
operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1069+
FloatType floatTy = resEleType.cast<FloatType>();
1070+
APFloat apf(floatTy.getFloatSemantics(),
1071+
APInt::getZero(floatTy.getWidth()));
1072+
apf.convertFromAPInt(a, /*IsSigned=*/false,
1073+
APFloat::rmNearestTiesToEven);
1074+
return apf;
1075+
});
10601076
}
10611077

10621078
//===----------------------------------------------------------------------===//
@@ -1068,15 +1084,21 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
10681084
}
10691085

10701086
OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1071-
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
1072-
const APInt &api = lhs.getValue();
1073-
FloatType floatTy = getType().cast<FloatType>();
1074-
APFloat apf(floatTy.getFloatSemantics(),
1075-
APInt::getZero(floatTy.getWidth()));
1076-
apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
1077-
return FloatAttr::get(floatTy, apf);
1078-
}
1079-
return {};
1087+
Type resType = getType();
1088+
Type resEleType;
1089+
if (auto shapedType = resType.dyn_cast<ShapedType>())
1090+
resEleType = shapedType.getElementType();
1091+
else
1092+
resEleType = resType;
1093+
return constFoldCastOp<IntegerAttr, FloatAttr>(
1094+
operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1095+
FloatType floatTy = resEleType.cast<FloatType>();
1096+
APFloat apf(floatTy.getFloatSemantics(),
1097+
APInt::getZero(floatTy.getWidth()));
1098+
apf.convertFromAPInt(a, /*IsSigned=*/true,
1099+
APFloat::rmNearestTiesToEven);
1100+
return apf;
1101+
});
10801102
}
10811103
//===----------------------------------------------------------------------===//
10821104
// FPToUIOp
@@ -1087,21 +1109,21 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
10871109
}
10881110

10891111
OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1090-
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1091-
const APFloat &apf = lhs.getValue();
1092-
IntegerType intTy = getType().cast<IntegerType>();
1093-
bool ignored;
1094-
APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1095-
if (APFloat::opInvalidOp ==
1096-
apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1097-
// Undefined behavior invoked - the destination type can't represent
1098-
// the input constant.
1099-
return {};
1100-
}
1101-
return IntegerAttr::get(getType(), api);
1102-
}
1103-
1104-
return {};
1112+
Type resType = getType();
1113+
Type resEleType;
1114+
if (auto shapedType = resType.dyn_cast<ShapedType>())
1115+
resEleType = shapedType.getElementType();
1116+
else
1117+
resEleType = resType;
1118+
return constFoldCastOp<FloatAttr, IntegerAttr>(
1119+
operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1120+
IntegerType intTy = resEleType.cast<IntegerType>();
1121+
bool ignored;
1122+
APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1123+
castStatus = APFloat::opInvalidOp !=
1124+
a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1125+
return api;
1126+
});
11051127
}
11061128

11071129
//===----------------------------------------------------------------------===//
@@ -1113,21 +1135,21 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
11131135
}
11141136

11151137
OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1116-
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
1117-
const APFloat &apf = lhs.getValue();
1118-
IntegerType intTy = getType().cast<IntegerType>();
1119-
bool ignored;
1120-
APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1121-
if (APFloat::opInvalidOp ==
1122-
apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
1123-
// Undefined behavior invoked - the destination type can't represent
1124-
// the input constant.
1125-
return {};
1126-
}
1127-
return IntegerAttr::get(getType(), api);
1128-
}
1129-
1130-
return {};
1138+
Type resType = getType();
1139+
Type resEleType;
1140+
if (auto shapedType = resType.dyn_cast<ShapedType>())
1141+
resEleType = shapedType.getElementType();
1142+
else
1143+
resEleType = resType;
1144+
return constFoldCastOp<FloatAttr, IntegerAttr>(
1145+
operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1146+
IntegerType intTy = resEleType.cast<IntegerType>();
1147+
bool ignored;
1148+
APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1149+
castStatus = APFloat::opInvalidOp !=
1150+
a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1151+
return api;
1152+
});
11311153
}
11321154

11331155
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)