Skip to content

Commit c99b3c0

Browse files
committed
[mlir][tblgen] add concrete create methods
1 parent 2dfcc43 commit c99b3c0

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

mlir/include/mlir/TableGen/Class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class MethodParameter {
7171
StringRef getName() const { return name; }
7272
/// Returns true if the parameter has a default value.
7373
bool hasDefaultValue() const { return !defaultValue.empty(); }
74+
StringRef getDefaultValue() const { return defaultValue; }
75+
bool isOptional() const { return optional; }
7476

7577
private:
7678
/// The C++ type.

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230230
231231
)";
232232

233+
static const char *const inlineCreateBody = R"(
234+
::mlir::OperationState __state__({0}, getOperationName());
235+
build(builder, __state__{1});
236+
auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
237+
assert(__res__ && "builder didn't return the right type");
238+
return __res__;
239+
)";
240+
233241
//===----------------------------------------------------------------------===//
234242
// Utility structs and functions
235243
//===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665673
// Generates the build() method that takes each operand/attribute
666674
// as a stand-alone parameter.
667675
void genSeparateArgParamBuilder();
676+
void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
668677

669678
// Generates the build() method that takes each operand/attribute as a
670679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2566,39 @@ static bool canInferType(const Operator &op) {
25572566
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
25582567
}
25592568

2569+
void OpEmitter::genInlineCreateBody(
2570+
const SmallVector<MethodParameter> &paramList) {
2571+
SmallVector<MethodParameter> createParamList;
2572+
SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
2573+
createParamList.emplace_back("::mlir::OpBuilder &", "builder");
2574+
std::string locParamName = "location";
2575+
while (llvm::find_if(paramList, [&locParamName](const MethodParameter &p) {
2576+
return p.getName() == locParamName;
2577+
}) != paramList.end()) {
2578+
locParamName += "_";
2579+
}
2580+
createParamList.emplace_back("::mlir::Location", locParamName);
2581+
2582+
for (auto &param : paramList) {
2583+
if (param.getType() == "::mlir::OpBuilder &" ||
2584+
param.getType() == "::mlir::OperationState &")
2585+
continue;
2586+
createParamList.emplace_back(param.getType(), param.getName(),
2587+
param.getDefaultValue(), param.isOptional());
2588+
nonBuilderStateArgsList.push_back(param.getName());
2589+
}
2590+
auto *c = opClass.addStaticMethod(opClass.getClassName(), "create",
2591+
createParamList);
2592+
std::string nonBuilderStateArgs = "";
2593+
if (!nonBuilderStateArgsList.empty()) {
2594+
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2595+
interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
2596+
nonBuilderStateArgs = ", " + nonBuilderStateArgs;
2597+
}
2598+
c->body() << llvm::formatv(inlineCreateBody, locParamName,
2599+
nonBuilderStateArgs, opClass.getClassName());
2600+
}
2601+
25602602
void OpEmitter::genSeparateArgParamBuilder() {
25612603
SmallVector<AttrParamKind, 2> attrBuilderType;
25622604
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2573,10 +2615,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
25732615
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
25742616
attrType);
25752617

2576-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2618+
auto *m = opClass.addStaticMethod("void", "build", paramList);
25772619
// If the builder is redundant, skip generating the method.
25782620
if (!m)
25792621
return;
2622+
genInlineCreateBody(paramList);
2623+
25802624
auto &body = m->body();
25812625
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
25822626
/*isRawValueAttr=*/attrType ==
@@ -2701,10 +2745,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
27012745
if (op.getNumVariadicRegions())
27022746
paramList.emplace_back("unsigned", "numRegions");
27032747

2704-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2748+
auto *m = opClass.addStaticMethod("void", "build", paramList);
27052749
// If the builder is redundant, skip generating the method
27062750
if (!m)
27072751
return;
2752+
genInlineCreateBody(paramList);
27082753
auto &body = m->body();
27092754

27102755
// Operands
@@ -2815,10 +2860,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
28152860
if (op.getNumVariadicRegions())
28162861
paramList.emplace_back("unsigned", "numRegions");
28172862

2818-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2863+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28192864
// If the builder is redundant, skip generating the method
28202865
if (!m)
28212866
return;
2867+
genInlineCreateBody(paramList);
28222868
auto &body = m->body();
28232869

28242870
int numResults = op.getNumResults();
@@ -2895,10 +2941,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
28952941
buildParamList(paramList, inferredAttributes, resultNames,
28962942
TypeParamKind::None, attrType);
28972943

2898-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2944+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28992945
// If the builder is redundant, skip generating the method
29002946
if (!m)
29012947
return;
2948+
genInlineCreateBody(paramList);
29022949
auto &body = m->body();
29032950
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
29042951
/*isRawValueAttr=*/attrType ==
@@ -2937,10 +2984,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
29372984
: "attributes";
29382985
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
29392986
attributesName, "{}");
2940-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2987+
auto *m = opClass.addStaticMethod("void", "build", paramList);
29412988
// If the builder is redundant, skip generating the method
29422989
if (!m)
29432990
return;
2991+
genInlineCreateBody(paramList);
29442992

29452993
auto &body = m->body();
29462994

@@ -3103,10 +3151,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
31033151
if (op.getNumVariadicRegions())
31043152
paramList.emplace_back("unsigned", "numRegions");
31053153

3106-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
3154+
auto *m = opClass.addStaticMethod("void", "build", paramList);
31073155
// If the builder is redundant, skip generating the method
31083156
if (!m)
31093157
return;
3158+
genInlineCreateBody(paramList);
31103159
auto &body = m->body();
31113160

31123161
// Operands

0 commit comments

Comments
 (0)