@@ -77,6 +77,31 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
77
77
}
78
78
};
79
79
80
+ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
81
+ using OpConversionPattern::OpConversionPattern;
82
+
83
+ LogicalResult
84
+ matchAndRewrite (memref::CopyOp op, OpAdaptor operands,
85
+ ConversionPatternRewriter &rewriter) const override {
86
+ return failure ();
87
+ }
88
+ };
89
+
90
+ static Type convertGlobalMemrefTypeToEmitc (MemRefType type,
91
+ const TypeConverter &typeConverter) {
92
+ Type elementType = typeConverter.convertType (type.getElementType ());
93
+ Type arrayTy = elementType;
94
+ // Shape has the outermost dim at index 0, so need to walk it backwards
95
+ auto shape = type.getShape ();
96
+ if (shape.empty ()) {
97
+ arrayTy = emitc::ArrayType::get ({1 }, arrayTy);
98
+ } else {
99
+ // For non-zero dimensions, use the original shape
100
+ arrayTy = emitc::ArrayType::get (shape, arrayTy);
101
+ }
102
+ return arrayTy;
103
+ }
104
+
80
105
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
81
106
using OpConversionPattern::OpConversionPattern;
82
107
@@ -95,7 +120,8 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
95
120
op.getLoc (), " global variable with alignment requirement is "
96
121
" currently not supported" );
97
122
}
98
- auto resultTy = getTypeConverter ()->convertType (op.getType ());
123
+ auto resultTy =
124
+ convertGlobalMemrefTypeToEmitc (op.getType (), *getTypeConverter ());
99
125
if (!resultTy) {
100
126
return rewriter.notifyMatchFailure (op.getLoc (),
101
127
" cannot convert result type" );
@@ -114,6 +140,15 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
114
140
bool externSpecifier = !staticSpecifier;
115
141
116
142
Attribute initialValue = operands.getInitialValueAttr ();
143
+ if (op.getType ().getRank () == 0 ) {
144
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
145
+ auto scalarValue = elementsAttr.getSplatValue <Attribute>();
146
+
147
+ // Convert scalar value to single-element array
148
+ initialValue = DenseElementsAttr::get (
149
+ RankedTensorType::get ({1 }, elementsAttr.getElementType ()),
150
+ {scalarValue});
151
+ }
117
152
if (isa_and_present<UnitAttr>(initialValue))
118
153
initialValue = {};
119
154
0 commit comments