Skip to content

Commit ae157aa

Browse files
committed
[mlir][tblgen] add concrete create methods
1 parent 2dfcc43 commit ae157aa

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

mlir/include/mlir/TableGen/Class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class MethodParameter {
7171
StringRef getName() const { return name; }
7272
/// Returns true if the parameter has a default value.
7373
bool hasDefaultValue() const { return !defaultValue.empty(); }
74+
StringRef getDefaultValue() const { return defaultValue; }
75+
bool isOptional() const { return optional; }
7476

7577
private:
7678
/// The C++ type.

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230230
231231
)";
232232

233+
static const char *const inlineCreateBody = R"(
234+
::mlir::OperationState __state__({0}, getOperationName());
235+
build(builder, __state__{1});
236+
auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
237+
assert(__res__ && "builder didn't return the right type");
238+
return __res__;
239+
)";
240+
233241
//===----------------------------------------------------------------------===//
234242
// Utility structs and functions
235243
//===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665673
// Generates the build() method that takes each operand/attribute
666674
// as a stand-alone parameter.
667675
void genSeparateArgParamBuilder();
676+
void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
668677

669678
// Generates the build() method that takes each operand/attribute as a
670679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2566,40 @@ static bool canInferType(const Operator &op) {
25572566
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
25582567
}
25592568

2569+
void OpEmitter::genInlineCreateBody(
2570+
const SmallVector<MethodParameter> &paramList) {
2571+
SmallVector<MethodParameter> createParamList;
2572+
SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
2573+
createParamList.emplace_back("::mlir::OpBuilder &", "builder");
2574+
std::string locParamName = "location";
2575+
const MethodParameter *locCollision =
2576+
llvm::find_if(paramList, [&locParamName](const MethodParameter &p) {
2577+
return p.getName() == locParamName;
2578+
});
2579+
while (locCollision && locCollision->getName() == locParamName)
2580+
locParamName += "_";
2581+
createParamList.emplace_back("::mlir::Location", locParamName);
2582+
2583+
for (auto &param : paramList) {
2584+
if (param.getType() == "::mlir::OpBuilder &" ||
2585+
param.getType() == "::mlir::OperationState &")
2586+
continue;
2587+
createParamList.emplace_back(param.getType(), param.getName(),
2588+
param.getDefaultValue(), param.isOptional());
2589+
nonBuilderStateArgsList.push_back(param.getName());
2590+
}
2591+
auto *c = opClass.addStaticMethod(opClass.getClassName(), "create",
2592+
createParamList);
2593+
std::string nonBuilderStateArgs = "";
2594+
if (!nonBuilderStateArgsList.empty()) {
2595+
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2596+
interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
2597+
nonBuilderStateArgs = ", " + nonBuilderStateArgs;
2598+
}
2599+
c->body() << llvm::formatv(inlineCreateBody, locParamName,
2600+
nonBuilderStateArgs, opClass.getClassName());
2601+
}
2602+
25602603
void OpEmitter::genSeparateArgParamBuilder() {
25612604
SmallVector<AttrParamKind, 2> attrBuilderType;
25622605
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2573,10 +2616,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
25732616
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
25742617
attrType);
25752618

2576-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2619+
auto *m = opClass.addStaticMethod("void", "build", paramList);
25772620
// If the builder is redundant, skip generating the method.
25782621
if (!m)
25792622
return;
2623+
genInlineCreateBody(paramList);
2624+
25802625
auto &body = m->body();
25812626
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
25822627
/*isRawValueAttr=*/attrType ==
@@ -2701,10 +2746,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
27012746
if (op.getNumVariadicRegions())
27022747
paramList.emplace_back("unsigned", "numRegions");
27032748

2704-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2749+
auto *m = opClass.addStaticMethod("void", "build", paramList);
27052750
// If the builder is redundant, skip generating the method
27062751
if (!m)
27072752
return;
2753+
genInlineCreateBody(paramList);
27082754
auto &body = m->body();
27092755

27102756
// Operands
@@ -2815,10 +2861,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
28152861
if (op.getNumVariadicRegions())
28162862
paramList.emplace_back("unsigned", "numRegions");
28172863

2818-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2864+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28192865
// If the builder is redundant, skip generating the method
28202866
if (!m)
28212867
return;
2868+
genInlineCreateBody(paramList);
28222869
auto &body = m->body();
28232870

28242871
int numResults = op.getNumResults();
@@ -2895,10 +2942,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
28952942
buildParamList(paramList, inferredAttributes, resultNames,
28962943
TypeParamKind::None, attrType);
28972944

2898-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2945+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28992946
// If the builder is redundant, skip generating the method
29002947
if (!m)
29012948
return;
2949+
genInlineCreateBody(paramList);
29022950
auto &body = m->body();
29032951
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
29042952
/*isRawValueAttr=*/attrType ==
@@ -2937,10 +2985,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
29372985
: "attributes";
29382986
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
29392987
attributesName, "{}");
2940-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2988+
auto *m = opClass.addStaticMethod("void", "build", paramList);
29412989
// If the builder is redundant, skip generating the method
29422990
if (!m)
29432991
return;
2992+
genInlineCreateBody(paramList);
29442993

29452994
auto &body = m->body();
29462995

@@ -3103,10 +3152,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
31033152
if (op.getNumVariadicRegions())
31043153
paramList.emplace_back("unsigned", "numRegions");
31053154

3106-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
3155+
auto *m = opClass.addStaticMethod("void", "build", paramList);
31073156
// If the builder is redundant, skip generating the method
31083157
if (!m)
31093158
return;
3159+
genInlineCreateBody(paramList);
31103160
auto &body = m->body();
31113161

31123162
// Operands

0 commit comments

Comments
 (0)