diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 91ee89919e58e..f5d163a45b08e 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1644,17 +1644,14 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { Example: ```mlir - // Example with an attribute: - emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"} // Example with no attribute: emitc.field @fieldName0 : !emitc.array<1xf32> ``` }]; - let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, - OptionalAttr:$attrs); + let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type); - let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}]; + let assemblyFormat = [{ $sym_name `:` $type attr-dict}]; let hasVerifier = 1; } @@ -1679,4 +1676,42 @@ def EmitC_GetFieldOp let assemblyFormat = "$field_name `:` type($result) attr-dict"; } +def BufferMapOp + : EmitC_Op<"buffer_map", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "Creates a buffer map for field access"; + let description = [{ + The `emitc.buffer_map` operation generates a C++ std::map that maps field names + to their memory addresses for efficient runtime field access. This operation + collects fields with buffer attributes and creates the necessary lookup + infrastructure. + + Example: + + ```mlir + emitc.buffer_map [ @field1, @field2, @field3 ] + ``` + + This generates C++ code like: + + ```cpp + const std::map _buffer_map { + { "some_feature", reinterpret_cast(&some_feature) }, + { "another_feature", reinterpret_cast(&another_feature) }, + { "input_tense", reinterpret_cast(&input_tense) } + }; + + char* getBufferForName(const std::string& name) const { + auto it = _reflection_map.find(name); + return (it == _reflection_map.end()) ? nullptr : it->second; + } + ``` + }]; + + let arguments = (ins Arg, "field names">:$fields); + + let results = (outs); + let assemblyFormat = "$fields attr-dict"; +} + #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 27298e892e599..b0ff4ed7bb688 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1415,9 +1415,6 @@ LogicalResult FieldOp::verify() { if (!symName || symName.getValue().empty()) return emitOpError("field must have a non-empty symbol name"); - if (!getAttrs()) - return success(); - return success(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 17d436f6df028..3e627708c6f3e 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -53,22 +53,32 @@ class WrapFuncInClass : public OpRewritePattern { ClassOp newClassOp = rewriter.create(funcOp.getLoc(), className); SmallVector> fields; + SmallVector bufferFieldAttrs; rewriter.createBlock(&newClassOp.getBody()); rewriter.setInsertionPointToStart(&newClassOp.getBody().front()); auto argAttrs = funcOp.getArgAttrs(); for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) { StringAttr fieldName; - Attribute argAttr = nullptr; fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx)); - if (argAttrs && idx < argAttrs->size()) - argAttr = (*argAttrs)[idx]; - + if (argAttrs && idx < argAttrs->size()) { + mlir::DictionaryAttr dictAttr = + dyn_cast_or_null((*argAttrs)[idx]); + const mlir::Attribute namedAttribute = + dictAttr.getNamed(attributeName)->getValue(); + + auto name = cast(namedAttribute); + bufferFieldAttrs.push_back(name); + } TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - rewriter.create(funcOp.getLoc(), fieldName, typeAttr, - argAttr); + rewriter.create(funcOp.getLoc(), fieldName, typeAttr); + } + + if (!bufferFieldAttrs.empty()) { + ArrayAttr fieldsArrayAttr = rewriter.getArrayAttr(bufferFieldAttrs); + rewriter.create(funcOp.getLoc(), fieldsArrayAttr); } rewriter.setInsertionPointToEnd(&newClassOp.getBody().front()); diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir index c67a0c197fcd9..364840adb3e5f 100644 --- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir @@ -19,9 +19,10 @@ module attributes { } { // CHECK: module { // CHECK-NEXT: emitc.class @modelClass { -// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} -// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} -// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} +// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.buffer_map ["another_feature", "some_feature", "output_0"] // CHECK-NEXT: emitc.func @execute() { // CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32> // CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>