Skip to content

Commit 36b61a6

Browse files
committed
global and getGlobal
1 parent c237251 commit 36b61a6

File tree

2 files changed

+24
-28
lines changed

2 files changed

+24
-28
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/TypeRange.h"
2022
#include "mlir/Transforms/DialectConversion.h"
2123

2224
using namespace mlir;
@@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7779
}
7880
};
7981

80-
struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
81-
using OpConversionPattern::OpConversionPattern;
82-
83-
LogicalResult
84-
matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
85-
ConversionPatternRewriter &rewriter) const override {
86-
return failure();
87-
}
88-
};
89-
9082
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
9183
using OpConversionPattern::OpConversionPattern;
9284

9385
LogicalResult
9486
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
9587
ConversionPatternRewriter &rewriter) const override {
96-
auto type = op.getType();
88+
MemRefType type = op.getType();
9789
if (!op.getType().hasStaticShape()) {
9890
return rewriter.notifyMatchFailure(
9991
op.getLoc(), "cannot transform global with dynamic shape");
@@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
10597
op.getLoc(), "global variable with alignment requirement is "
10698
"currently not supported");
10799
}
108-
// auto resultTy =
109-
// convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
100+
110101
Type resultTy;
111-
Type elementType = getTypeConverter()->convertType(type.getElementType());
112-
auto shape = type.getShape();
113-
114-
if (shape.empty()) {
115-
if (emitc::isSupportedFloatType(elementType)) {
116-
resultTy = rewriter.getF32Type();
117-
}
118-
if (emitc::isSupportedIntegerType(elementType)) {
119-
resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
120-
}
121-
} else {
122-
resultTy = emitc::ArrayType::get(shape, elementType);
123-
}
102+
if (type.getRank() == 0)
103+
resultTy = getTypeConverter()->convertType(type.getElementType());
104+
else
105+
resultTy = getTypeConverter()->convertType(type);
124106

125107
if (!resultTy) {
126108
return rewriter.notifyMatchFailure(op.getLoc(),
@@ -140,7 +122,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
140122
bool externSpecifier = !staticSpecifier;
141123

142124
Attribute initialValue = operands.getInitialValueAttr();
143-
if (op.getType().getRank() == 0) {
125+
if (type.getRank() == 0) {
144126
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
145127
initialValue = elementsAttr.getSplatValue<Attribute>();
146128
}
@@ -162,7 +144,17 @@ struct ConvertGetGlobal final
162144
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
163145
ConversionPatternRewriter &rewriter) const override {
164146

165-
auto resultTy = getTypeConverter()->convertType(op.getType());
147+
MemRefType type = op.getType();
148+
Type resultTy;
149+
if (type.getRank() == 0)
150+
resultTy = emitc::LValueType::get(
151+
getTypeConverter()->convertType(type.getElementType()));
152+
else
153+
resultTy = getTypeConverter()->convertType(type);
154+
155+
if (!resultTy)
156+
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
157+
166158
if (!resultTy) {
167159
return rewriter.notifyMatchFailure(op.getLoc(),
168160
"cannot convert result type");

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
4141
module @globals {
4242
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
4343
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
44+
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
45+
// CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
4446
memref.global @public_global : memref<3x7xf32>
4547
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
4648
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
5052
func.func @use_global() {
5153
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
5254
%0 = memref.get_global @public_global : memref<3x7xf32>
55+
// CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
56+
%1 = memref.get_global @__constant_xi32 : memref<i32>
5357
return
5458
}
5559
}

0 commit comments

Comments
 (0)