16
16
#include " mlir/Dialect/EmitC/IR/EmitC.h"
17
17
#include " mlir/Dialect/MemRef/IR/MemRef.h"
18
18
#include " mlir/IR/Builders.h"
19
+ #include " mlir/IR/BuiltinTypes.h"
19
20
#include " mlir/IR/PatternMatch.h"
21
+ #include " mlir/IR/TypeRange.h"
20
22
#include " mlir/Transforms/DialectConversion.h"
21
23
22
24
using namespace mlir ;
@@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
77
79
}
78
80
};
79
81
80
- struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
81
- using OpConversionPattern::OpConversionPattern;
82
-
83
- LogicalResult
84
- matchAndRewrite (memref::CopyOp op, OpAdaptor operands,
85
- ConversionPatternRewriter &rewriter) const override {
86
- return failure ();
87
- }
88
- };
89
-
90
82
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
91
83
using OpConversionPattern::OpConversionPattern;
92
84
93
85
LogicalResult
94
86
matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
95
87
ConversionPatternRewriter &rewriter) const override {
96
- auto type = op.getType ();
88
+ MemRefType type = op.getType ();
97
89
if (!op.getType ().hasStaticShape ()) {
98
90
return rewriter.notifyMatchFailure (
99
91
op.getLoc (), " cannot transform global with dynamic shape" );
@@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
105
97
op.getLoc (), " global variable with alignment requirement is "
106
98
" currently not supported" );
107
99
}
108
- // auto resultTy =
109
- // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
100
+
110
101
Type resultTy;
111
- Type elementType = getTypeConverter ()->convertType (type.getElementType ());
112
- auto shape = type.getShape ();
113
-
114
- if (shape.empty ()) {
115
- if (emitc::isSupportedFloatType (elementType)) {
116
- resultTy = rewriter.getF32Type ();
117
- }
118
- if (emitc::isSupportedIntegerType (elementType)) {
119
- resultTy = rewriter.getIntegerType (elementType.getIntOrFloatBitWidth ());
120
- }
121
- } else {
122
- resultTy = emitc::ArrayType::get (shape, elementType);
123
- }
102
+ if (type.getRank () == 0 )
103
+ resultTy = getTypeConverter ()->convertType (type.getElementType ());
104
+ else
105
+ resultTy = getTypeConverter ()->convertType (type);
124
106
125
107
if (!resultTy) {
126
108
return rewriter.notifyMatchFailure (op.getLoc (),
@@ -140,7 +122,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
140
122
bool externSpecifier = !staticSpecifier;
141
123
142
124
Attribute initialValue = operands.getInitialValueAttr ();
143
- if (op. getType () .getRank () == 0 ) {
125
+ if (type .getRank () == 0 ) {
144
126
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
145
127
initialValue = elementsAttr.getSplatValue <Attribute>();
146
128
}
@@ -162,7 +144,17 @@ struct ConvertGetGlobal final
162
144
matchAndRewrite (memref::GetGlobalOp op, OpAdaptor operands,
163
145
ConversionPatternRewriter &rewriter) const override {
164
146
165
- auto resultTy = getTypeConverter ()->convertType (op.getType ());
147
+ MemRefType type = op.getType ();
148
+ Type resultTy;
149
+ if (type.getRank () == 0 )
150
+ resultTy = emitc::LValueType::get (
151
+ getTypeConverter ()->convertType (type.getElementType ()));
152
+ else
153
+ resultTy = getTypeConverter ()->convertType (type);
154
+
155
+ if (!resultTy)
156
+ return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
157
+
166
158
if (!resultTy) {
167
159
return rewriter.notifyMatchFailure (op.getLoc (),
168
160
" cannot convert result type" );
0 commit comments