Skip to content

Expand the MemRefToEmitC pass - Adding scalars #148055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Jaddyen
Copy link
Contributor

@Jaddyen Jaddyen commented Jul 10, 2025

This aims to expand the the MemRefToEmitC pass so that it can accept global scalars.

@Jaddyen Jaddyen changed the title Expand the MemRef to EmitC pass Expand the MemRef to EmitC pass - Adding scalars Jul 10, 2025
@Jaddyen Jaddyen changed the title Expand the MemRef to EmitC pass - Adding scalars Expand the MemRefToEmitC pass - Adding scalars Jul 10, 2025
@Jaddyen Jaddyen marked this pull request as ready for review July 10, 2025 23:50
@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Jaden Angella (Jaddyen)

Changes

This aims to expand the the MemRefToEmitC pass so that it can accept global scalars.


Full diff: https://github.com/llvm/llvm-project/pull/148055.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+25-3)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+4)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   LogicalResult
   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-
+    MemRefType type = op.getType();
     if (!op.getType().hasStaticShape()) {
       return rewriter.notifyMatchFailure(
           op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +97,13 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
           op.getLoc(), "global variable with alignment requirement is "
                        "currently not supported");
     }
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+
+    Type resultTy;
+    if (type.getRank() == 0)
+      resultTy = getTypeConverter()->convertType(type.getElementType());
+    else
+      resultTy = getTypeConverter()->convertType(type);
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     bool externSpecifier = !staticSpecifier;
 
     Attribute initialValue = operands.getInitialValueAttr();
+    if (type.getRank() == 0) {
+      auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+      initialValue = elementsAttr.getSplatValue<Attribute>();
+    }
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};
 
@@ -132,7 +144,17 @@ struct ConvertGetGlobal final
   matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+    MemRefType type = op.getType();
+    Type resultTy;
+    if (type.getRank() == 0)
+      resultTy = emitc::LValueType::get(
+          getTypeConverter()->convertType(type.getElementType()));
+    else
+      resultTy = getTypeConverter()->convertType(type);
+
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
 module @globals {
   memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
   // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+  memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+  // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
   memref.global @public_global : memref<3x7xf32>
   // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
   memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
   func.func @use_global() {
     // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
     %0 = memref.get_global @public_global : memref<3x7xf32>
+    // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+    %1 = memref.get_global @__constant_xi32 : memref<i32>
     return
   }
 }

@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

MemRefType type = op.getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
MemRefType type = op.getType();
MemRefType opTy = op.getType();

I'm not sure we want to use type as a variable name...

Comment on lines +102 to +105
if (type.getRank() == 0)
resultTy = getTypeConverter()->convertType(type.getElementType());
else
resultTy = getTypeConverter()->convertType(type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe introduce a helper, since I see a similar pattern in a few spots?

@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting to see a corresponding change to memref-to-emitc-failed.mlir. Looking I suppose it isn't there, but are any of the cases, like

going to work now?


Type resultTy;
if (type.getRank() == 0)
resultTy = getTypeConverter()->convertType(type.getElementType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to just do

resultTy = getTypeConverter()->convertType(getElementTypeOrSelf(type));

@@ -50,6 +52,8 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
// CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = memref.get_global @__constant_xi32 : memref<i32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check, so previously memref<1xi32> worked, but memref<i32> didn't work before this?

@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;

Attribute initialValue = operands.getInitialValueAttr();
if (type.getRank() == 0) {
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
initialValue = elementsAttr.getSplatValue<Attribute>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the initialValue before this vs splat value returned?

MemRefType type = op.getType();
Type resultTy;
if (type.getRank() == 0)
resultTy = emitc::LValueType::get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be LValue, while one below not?

resultTy = getTypeConverter()->convertType(type);

if (!resultTy)
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you have this check just below too?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants