From 127a1ed01365f40c719a430d3d6ae56390728bb3 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Mon, 7 Jul 2025 18:49:20 +0000 Subject: [PATCH 1/3] Initial work on memref ops --- .../MemRefToEmitC/MemRefToEmitC.cpp | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1d1cac8..742d2bfff27de 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -77,6 +77,31 @@ struct ConvertAlloca final : public OpConversionPattern { } }; +struct ConvertCopy final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + return failure(); + } +}; + +static Type convertGlobalMemrefTypeToEmitc(MemRefType type, + const TypeConverter &typeConverter) { + Type elementType = typeConverter.convertType(type.getElementType()); + Type arrayTy = elementType; + // Shape has the outermost dim at index 0, so need to walk it backwards + auto shape = type.getShape(); + if (shape.empty()) { + arrayTy = emitc::ArrayType::get({1}, arrayTy); + } else { + // For non-zero dimensions, use the original shape + arrayTy = emitc::ArrayType::get(shape, arrayTy); + } + return arrayTy; +} + struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -95,7 +120,8 @@ struct ConvertGlobal final : public OpConversionPattern { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - auto resultTy = getTypeConverter()->convertType(op.getType()); + auto resultTy = + convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter()); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); @@ -114,6 +140,15 @@ struct ConvertGlobal final : public OpConversionPattern { bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); + if (op.getType().getRank() == 0) { + auto elementsAttr = llvm::cast(*op.getInitialValue()); + auto scalarValue = elementsAttr.getSplatValue(); + + // Convert scalar value to single-element array + initialValue = DenseElementsAttr::get( + RankedTensorType::get({1}, elementsAttr.getElementType()), + {scalarValue}); + } if (isa_and_present(initialValue)) initialValue = {}; From c2372515aaeafba1d11293ecdc7cf013c130b9e5 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Thu, 10 Jul 2025 20:36:08 +0000 Subject: [PATCH 2/3] Convert scalars to constants --- .../MemRefToEmitC/MemRefToEmitC.cpp | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 742d2bfff27de..f69a362395ef6 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern { } }; -static Type convertGlobalMemrefTypeToEmitc(MemRefType type, - const TypeConverter &typeConverter) { - Type elementType = typeConverter.convertType(type.getElementType()); - Type arrayTy = elementType; - // Shape has the outermost dim at index 0, so need to walk it backwards - auto shape = type.getShape(); - if (shape.empty()) { - arrayTy = emitc::ArrayType::get({1}, arrayTy); - } else { - // For non-zero dimensions, use the original shape - arrayTy = emitc::ArrayType::get(shape, arrayTy); - } - return arrayTy; -} - struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - + auto type = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); @@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - auto resultTy = - convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter()); + // auto resultTy = + // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter()); + Type resultTy; + Type elementType = getTypeConverter()->convertType(type.getElementType()); + auto shape = type.getShape(); + + if (shape.empty()) { + if (emitc::isSupportedFloatType(elementType)) { + resultTy = rewriter.getF32Type(); + } + if (emitc::isSupportedIntegerType(elementType)) { + resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()); + } + } else { + resultTy = emitc::ArrayType::get(shape, elementType); + } + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); @@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern { Attribute initialValue = operands.getInitialValueAttr(); if (op.getType().getRank() == 0) { auto elementsAttr = llvm::cast(*op.getInitialValue()); - auto scalarValue = elementsAttr.getSplatValue(); - - // Convert scalar value to single-element array - initialValue = DenseElementsAttr::get( - RankedTensorType::get({1}, elementsAttr.getElementType()), - {scalarValue}); + initialValue = elementsAttr.getSplatValue(); } if (isa_and_present(initialValue)) initialValue = {}; From 36b61a6b5731a524b4e79e77d7505c7a5ef3d0f9 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Thu, 10 Jul 2025 23:47:46 +0000 Subject: [PATCH 3/3] global and getGlobal --- .../MemRefToEmitC/MemRefToEmitC.cpp | 48 ++++++++----------- .../MemRefToEmitC/memref-to-emitc.mlir | 4 ++ 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index f69a362395ef6..e55c8e48ad105 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern { } }; -struct ConvertCopy final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::CopyOp op, OpAdaptor operands, - ConversionPatternRewriter &rewriter) const override { - return failure(); - } -}; - struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - auto type = op.getType(); + MemRefType type = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); @@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - // auto resultTy = - // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter()); + Type resultTy; - Type elementType = getTypeConverter()->convertType(type.getElementType()); - auto shape = type.getShape(); - - if (shape.empty()) { - if (emitc::isSupportedFloatType(elementType)) { - resultTy = rewriter.getF32Type(); - } - if (emitc::isSupportedIntegerType(elementType)) { - resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()); - } - } else { - resultTy = emitc::ArrayType::get(shape, elementType); - } + if (type.getRank() == 0) + resultTy = getTypeConverter()->convertType(type.getElementType()); + else + resultTy = getTypeConverter()->convertType(type); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), @@ -140,7 +122,7 @@ struct ConvertGlobal final : public OpConversionPattern { bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); - if (op.getType().getRank() == 0) { + if (type.getRank() == 0) { auto elementsAttr = llvm::cast(*op.getInitialValue()); initialValue = elementsAttr.getSplatValue(); } @@ -162,7 +144,17 @@ struct ConvertGetGlobal final matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); + MemRefType type = op.getType(); + Type resultTy; + if (type.getRank() == 0) + resultTy = emitc::LValueType::get( + getTypeConverter()->convertType(type.getElementType())); + else + resultTy = getTypeConverter()->convertType(type); + + if (!resultTy) + return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index d37fd1de90add..445a28534325a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 { module @globals { memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> + memref.global "private" constant @__constant_xi32 : memref = dense<-1> + // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1 memref.global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32> memref.global @uninitialized_global : memref<3x7xf32> = uninitialized @@ -50,6 +52,8 @@ module @globals { func.func @use_global() { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> + // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue + %1 = memref.get_global @__constant_xi32 : memref return } }