Skip to content

Commit c237251

Browse files
committed
Convert scalars to constants
1 parent 127a1ed commit c237251

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
8787
}
8888
};
8989

90-
static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
91-
const TypeConverter &typeConverter) {
92-
Type elementType = typeConverter.convertType(type.getElementType());
93-
Type arrayTy = elementType;
94-
// Shape has the outermost dim at index 0, so need to walk it backwards
95-
auto shape = type.getShape();
96-
if (shape.empty()) {
97-
arrayTy = emitc::ArrayType::get({1}, arrayTy);
98-
} else {
99-
// For non-zero dimensions, use the original shape
100-
arrayTy = emitc::ArrayType::get(shape, arrayTy);
101-
}
102-
return arrayTy;
103-
}
104-
10590
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
10691
using OpConversionPattern::OpConversionPattern;
10792

10893
LogicalResult
10994
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
11095
ConversionPatternRewriter &rewriter) const override {
111-
96+
auto type = op.getType();
11297
if (!op.getType().hasStaticShape()) {
11398
return rewriter.notifyMatchFailure(
11499
op.getLoc(), "cannot transform global with dynamic shape");
@@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
120105
op.getLoc(), "global variable with alignment requirement is "
121106
"currently not supported");
122107
}
123-
auto resultTy =
124-
convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
108+
// auto resultTy =
109+
// convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
110+
Type resultTy;
111+
Type elementType = getTypeConverter()->convertType(type.getElementType());
112+
auto shape = type.getShape();
113+
114+
if (shape.empty()) {
115+
if (emitc::isSupportedFloatType(elementType)) {
116+
resultTy = rewriter.getF32Type();
117+
}
118+
if (emitc::isSupportedIntegerType(elementType)) {
119+
resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
120+
}
121+
} else {
122+
resultTy = emitc::ArrayType::get(shape, elementType);
123+
}
124+
125125
if (!resultTy) {
126126
return rewriter.notifyMatchFailure(op.getLoc(),
127127
"cannot convert result type");
@@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
142142
Attribute initialValue = operands.getInitialValueAttr();
143143
if (op.getType().getRank() == 0) {
144144
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
145-
auto scalarValue = elementsAttr.getSplatValue<Attribute>();
146-
147-
// Convert scalar value to single-element array
148-
initialValue = DenseElementsAttr::get(
149-
RankedTensorType::get({1}, elementsAttr.getElementType()),
150-
{scalarValue});
145+
initialValue = elementsAttr.getSplatValue<Attribute>();
151146
}
152147
if (isa_and_present<UnitAttr>(initialValue))
153148
initialValue = {};

0 commit comments

Comments
 (0)