@@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
87
87
}
88
88
};
89
89
90
- static Type convertGlobalMemrefTypeToEmitc (MemRefType type,
91
- const TypeConverter &typeConverter) {
92
- Type elementType = typeConverter.convertType (type.getElementType ());
93
- Type arrayTy = elementType;
94
- // Shape has the outermost dim at index 0, so need to walk it backwards
95
- auto shape = type.getShape ();
96
- if (shape.empty ()) {
97
- arrayTy = emitc::ArrayType::get ({1 }, arrayTy);
98
- } else {
99
- // For non-zero dimensions, use the original shape
100
- arrayTy = emitc::ArrayType::get (shape, arrayTy);
101
- }
102
- return arrayTy;
103
- }
104
-
105
90
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
106
91
using OpConversionPattern::OpConversionPattern;
107
92
108
93
LogicalResult
109
94
matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
110
95
ConversionPatternRewriter &rewriter) const override {
111
-
96
+ auto type = op. getType ();
112
97
if (!op.getType ().hasStaticShape ()) {
113
98
return rewriter.notifyMatchFailure (
114
99
op.getLoc (), " cannot transform global with dynamic shape" );
@@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
120
105
op.getLoc (), " global variable with alignment requirement is "
121
106
" currently not supported" );
122
107
}
123
- auto resultTy =
124
- convertGlobalMemrefTypeToEmitc (op.getType (), *getTypeConverter ());
108
+ // auto resultTy =
109
+ // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
110
+ Type resultTy;
111
+ Type elementType = getTypeConverter ()->convertType (type.getElementType ());
112
+ auto shape = type.getShape ();
113
+
114
+ if (shape.empty ()) {
115
+ if (emitc::isSupportedFloatType (elementType)) {
116
+ resultTy = rewriter.getF32Type ();
117
+ }
118
+ if (emitc::isSupportedIntegerType (elementType)) {
119
+ resultTy = rewriter.getIntegerType (elementType.getIntOrFloatBitWidth ());
120
+ }
121
+ } else {
122
+ resultTy = emitc::ArrayType::get (shape, elementType);
123
+ }
124
+
125
125
if (!resultTy) {
126
126
return rewriter.notifyMatchFailure (op.getLoc (),
127
127
" cannot convert result type" );
@@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
142
142
Attribute initialValue = operands.getInitialValueAttr ();
143
143
if (op.getType ().getRank () == 0 ) {
144
144
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
145
- auto scalarValue = elementsAttr.getSplatValue <Attribute>();
146
-
147
- // Convert scalar value to single-element array
148
- initialValue = DenseElementsAttr::get (
149
- RankedTensorType::get ({1 }, elementsAttr.getElementType ()),
150
- {scalarValue});
145
+ initialValue = elementsAttr.getSplatValue <Attribute>();
151
146
}
152
147
if (isa_and_present<UnitAttr>(initialValue))
153
148
initialValue = {};
0 commit comments