Skip to content

Commit 127a1ed

Browse files
committed
Initial work on memref ops
1 parent 46e3ec0 commit 127a1ed

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,31 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7777
}
7878
};
7979

80+
struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
81+
using OpConversionPattern::OpConversionPattern;
82+
83+
LogicalResult
84+
matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
85+
ConversionPatternRewriter &rewriter) const override {
86+
return failure();
87+
}
88+
};
89+
90+
static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
91+
const TypeConverter &typeConverter) {
92+
Type elementType = typeConverter.convertType(type.getElementType());
93+
Type arrayTy = elementType;
94+
// Shape has the outermost dim at index 0, so need to walk it backwards
95+
auto shape = type.getShape();
96+
if (shape.empty()) {
97+
arrayTy = emitc::ArrayType::get({1}, arrayTy);
98+
} else {
99+
// For non-zero dimensions, use the original shape
100+
arrayTy = emitc::ArrayType::get(shape, arrayTy);
101+
}
102+
return arrayTy;
103+
}
104+
80105
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
81106
using OpConversionPattern::OpConversionPattern;
82107

@@ -95,7 +120,8 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
95120
op.getLoc(), "global variable with alignment requirement is "
96121
"currently not supported");
97122
}
98-
auto resultTy = getTypeConverter()->convertType(op.getType());
123+
auto resultTy =
124+
convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
99125
if (!resultTy) {
100126
return rewriter.notifyMatchFailure(op.getLoc(),
101127
"cannot convert result type");
@@ -114,6 +140,15 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
114140
bool externSpecifier = !staticSpecifier;
115141

116142
Attribute initialValue = operands.getInitialValueAttr();
143+
if (op.getType().getRank() == 0) {
144+
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
145+
auto scalarValue = elementsAttr.getSplatValue<Attribute>();
146+
147+
// Convert scalar value to single-element array
148+
initialValue = DenseElementsAttr::get(
149+
RankedTensorType::get({1}, elementsAttr.getElementType()),
150+
{scalarValue});
151+
}
117152
if (isa_and_present<UnitAttr>(initialValue))
118153
initialValue = {};
119154

0 commit comments

Comments
 (0)