Skip to content

Commit a469dc0

Browse files
committed
Started on the buffer_map
1 parent aeea062 commit a469dc0

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,4 +1679,44 @@ def EmitC_GetFieldOp
16791679
let assemblyFormat = "$field_name `:` type($result) attr-dict";
16801680
}
16811681

1682+
def BufferMapOp
1683+
: EmitC_Op<"buffer_map", [Pure,
1684+
DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
1685+
let summary = "Creates a buffer map for field access";
1686+
let description = [{
1687+
The `emitc.buffer_map` operation generates a C++ std::map that maps field names
1688+
to their memory addresses for efficient runtime field access. This operation
1689+
collects fields with buffer attributes and creates the necessary lookup
1690+
infrastructure.
1691+
1692+
Example:
1693+
1694+
```mlir
1695+
emitc.buffer_map reflection_map [ @field1, @field2, @field3 ]
1696+
```
1697+
1698+
This generates C++ code like:
1699+
1700+
```cpp
1701+
const std::map<std::string, char*> _reflection_map {
1702+
{ "field1", reinterpret_cast<char*>(&field1) },
1703+
{ "field2", reinterpret_cast<char*>(&field2) },
1704+
{ "field3", reinterpret_cast<char*>(&field3) }
1705+
};
1706+
1707+
char* getBufferForName(const std::string& name) const {
1708+
auto it = _reflection_map.find(name);
1709+
return (it == _reflection_map.end()) ? nullptr : it->second;
1710+
}
1711+
```
1712+
}];
1713+
1714+
let arguments = (ins SymbolNameAttr:$sym_name,
1715+
Arg<OptionalAttr<ArrayAttr>, "field names">:$fields);
1716+
1717+
let results = (outs);
1718+
let builders = [];
1719+
let assemblyFormat = "$sym_name $fields attr-dict";
1720+
}
1721+
16821722
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,35 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
5353
ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
5454

5555
SmallVector<std::pair<StringAttr, TypeAttr>> fields;
56+
SmallVector<Attribute> bufferFieldAttrs;
5657
rewriter.createBlock(&newClassOp.getBody());
5758
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
5859

5960
auto argAttrs = funcOp.getArgAttrs();
6061
for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
6162
StringAttr fieldName;
62-
Attribute argAttr = nullptr;
6363

6464
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
65-
if (argAttrs && idx < argAttrs->size())
66-
argAttr = (*argAttrs)[idx];
67-
65+
if (argAttrs && idx < argAttrs->size()) {
66+
mlir::DictionaryAttr dictAttr =
67+
dyn_cast_or_null<mlir::DictionaryAttr>((*argAttrs)[idx]);
68+
const mlir::Attribute namedAttribute =
69+
dictAttr.getNamed(attributeName)->getValue();
70+
71+
auto name =
72+
cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(namedAttribute)[0]);
73+
bufferFieldAttrs.push_back(name);
74+
}
6875
TypeAttr typeAttr = TypeAttr::get(val.getType());
6976
fields.push_back({fieldName, typeAttr});
7077
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
71-
argAttr);
78+
nullptr);
79+
}
80+
81+
if (!bufferFieldAttrs.empty()) {
82+
ArrayAttr fieldsArrayAttr = rewriter.getArrayAttr(bufferFieldAttrs);
83+
rewriter.create<emitc::BufferMapOp>(funcOp.getLoc(), "reflection_map",
84+
fieldsArrayAttr);
7285
}
7386

7487
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());

0 commit comments

Comments
 (0)