diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1d1cac8..e55c8e48ad105 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern { LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - + MemRefType type = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); @@ -95,7 +97,13 @@ struct ConvertGlobal final : public OpConversionPattern { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - auto resultTy = getTypeConverter()->convertType(op.getType()); + + Type resultTy; + if (type.getRank() == 0) + resultTy = getTypeConverter()->convertType(type.getElementType()); + else + resultTy = getTypeConverter()->convertType(type); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); @@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern { bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); + if (type.getRank() == 0) { + auto elementsAttr = llvm::cast(*op.getInitialValue()); + initialValue = elementsAttr.getSplatValue(); + } if (isa_and_present(initialValue)) initialValue = {}; @@ -132,7 +144,17 @@ struct ConvertGetGlobal final matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); + MemRefType type = op.getType(); + Type resultTy; + if (type.getRank() == 0) + resultTy = emitc::LValueType::get( + getTypeConverter()->convertType(type.getElementType())); + else + resultTy = getTypeConverter()->convertType(type); + + if (!resultTy) + return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index d37fd1de90add..445a28534325a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 { module @globals { memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> + memref.global "private" constant @__constant_xi32 : memref = dense<-1> + // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1 memref.global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32> memref.global @uninitialized_global : memref<3x7xf32> = uninitialized @@ -50,6 +52,8 @@ module @globals { func.func @use_global() { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> + // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue + %1 = memref.get_global @__constant_xi32 : memref return } }