diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 91ee89919e58e..f5d163a45b08e 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1644,17 +1644,14 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { Example: ```mlir - // Example with an attribute: - emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"} // Example with no attribute: emitc.field @fieldName0 : !emitc.array<1xf32> ``` }]; - let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, - OptionalAttr:$attrs); + let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type); - let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}]; + let assemblyFormat = [{ $sym_name `:` $type attr-dict}]; let hasVerifier = 1; } @@ -1679,4 +1676,42 @@ def EmitC_GetFieldOp let assemblyFormat = "$field_name `:` type($result) attr-dict"; } +def BufferMapOp + : EmitC_Op<"buffer_map", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "Creates a buffer map for field access"; + let description = [{ + The `emitc.buffer_map` operation generates a C++ std::map that maps field names + to their memory addresses for efficient runtime field access. This operation + collects fields with buffer attributes and creates the necessary lookup + infrastructure. + + Example: + + ```mlir + emitc.buffer_map [ @field1, @field2, @field3 ] + ``` + + This generates C++ code like: + + ```cpp + const std::map _buffer_map { + { "some_feature", reinterpret_cast(&some_feature) }, + { "another_feature", reinterpret_cast(&another_feature) }, + { "input_tense", reinterpret_cast(&input_tense) } + }; + + char* getBufferForName(const std::string& name) const { + auto it = _reflection_map.find(name); + return (it == _reflection_map.end()) ? nullptr : it->second; + } + ``` + }]; + + let arguments = (ins Arg, "field names">:$fields); + + let results = (outs); + let assemblyFormat = "$fields attr-dict"; +} + #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 27298e892e599..b0ff4ed7bb688 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1415,9 +1415,6 @@ LogicalResult FieldOp::verify() { if (!symName || symName.getValue().empty()) return emitOpError("field must have a non-empty symbol name"); - if (!getAttrs()) - return success(); - return success(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 17d436f6df028..3e627708c6f3e 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -53,22 +53,32 @@ class WrapFuncInClass : public OpRewritePattern { ClassOp newClassOp = rewriter.create(funcOp.getLoc(), className); SmallVector> fields; + SmallVector bufferFieldAttrs; rewriter.createBlock(&newClassOp.getBody()); rewriter.setInsertionPointToStart(&newClassOp.getBody().front()); auto argAttrs = funcOp.getArgAttrs(); for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) { StringAttr fieldName; - Attribute argAttr = nullptr; fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx)); - if (argAttrs && idx < argAttrs->size()) - argAttr = (*argAttrs)[idx]; - + if (argAttrs && idx < argAttrs->size()) { + mlir::DictionaryAttr dictAttr = + dyn_cast_or_null((*argAttrs)[idx]); + const mlir::Attribute namedAttribute = + dictAttr.getNamed(attributeName)->getValue(); + + auto name = cast(namedAttribute); + bufferFieldAttrs.push_back(name); + } TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - rewriter.create(funcOp.getLoc(), fieldName, typeAttr, - argAttr); + rewriter.create(funcOp.getLoc(), fieldName, typeAttr); + } + + if (!bufferFieldAttrs.empty()) { + ArrayAttr fieldsArrayAttr = rewriter.getArrayAttr(bufferFieldAttrs); + rewriter.create(funcOp.getLoc(), fieldsArrayAttr); } rewriter.setInsertionPointToEnd(&newClassOp.getBody().front()); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index c04548688bcf6..8fec3009ac9ef 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -24,6 +25,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -1038,6 +1040,28 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + BufferMapOp bufferMapOp) { + raw_indented_ostream &os = emitter.ostream(); + os << "\nconst std::map _buffer_map {\n"; + os.indent(); + auto buf = bufferMapOp.getFields(); + for (auto field : *buf) { + os << "{ \"" << field << "\", reinterpret_cast(&" << field + << ") },\n"; + } + os.unindent(); + os << "};\n"; + + os << "char* getBufferForName(const std::string& name) const {\n"; + os.indent(); + os << "auto it = _buffer_map.find(name);\n"; + os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n"; + os.unindent(); + os << "}\n"; + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, FileOp file) { if (!emitter.shouldEmitFile(file)) return success(); @@ -1645,17 +1669,17 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case( + emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, + emitc::BufferMapOp, emitc::CallOp, emitc::CallOpaqueOp, + emitc::CastOp, emitc::ClassOp, emitc::CmpOp, + emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, + emitc::DivOp, emitc::ExpressionOp, emitc::FieldOp, + emitc::FileOp, emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp, + emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir index c67a0c197fcd9..364840adb3e5f 100644 --- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir @@ -19,9 +19,10 @@ module attributes { } { // CHECK: module { // CHECK-NEXT: emitc.class @modelClass { -// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} -// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} -// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} +// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.buffer_map ["another_feature", "some_feature", "output_0"] // CHECK-NEXT: emitc.func @execute() { // CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32> // CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32> diff --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir index e42844412860e..f37aec457a051 100644 --- a/mlir/test/mlir-translate/emitc_classops.mlir +++ b/mlir/test/mlir-translate/emitc_classops.mlir @@ -3,6 +3,7 @@ emitc.class @modelClass { emitc.field @fieldName0 : !emitc.array<1xf32> emitc.field @fieldName1 : !emitc.array<1xf32> + emitc.buffer_map ["another_feature", "some_feature", "output_0"] emitc.func @execute() { %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t %1 = get_field @fieldName0 : !emitc.array<1xf32> @@ -16,6 +17,17 @@ emitc.class @modelClass { // CHECK-NEXT: public: // CHECK-NEXT: float[1] fieldName0; // CHECK-NEXT: float[1] fieldName1; +// CHECK-EMPTY: +// CHECK-NEXT: const std::map _buffer_map { +// CHECK-NEXT: { ""another_feature"", reinterpret_cast(&"another_feature") }, +// CHECK-NEXT: { ""some_feature"", reinterpret_cast(&"some_feature") }, +// CHECK-NEXT: { ""output_0"", reinterpret_cast(&"output_0") }, +// CHECK-NEXT: }; +// CHECK-NEXT: char* getBufferForName(const std::string& name) const { +// CHECK-NEXT: auto it = _buffer_map.find(name); +// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second; +// CHECK-NEXT: } +// CHECK-EMPTY: // CHECK-NEXT: void execute() { // CHECK-NEXT: size_t v1 = 0; // CHECK-NEXT: float[1] v2 = fieldName0;