Skip to content

Commit 0317629

Browse files
committed
[mlir][tblgen] add concrete create methods
1 parent 543f948 commit 0317629

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

mlir/include/mlir/TableGen/Class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class MethodParameter {
7171
StringRef getName() const { return name; }
7272
/// Returns true if the parameter has a default value.
7373
bool hasDefaultValue() const { return !defaultValue.empty(); }
74+
StringRef getDefaultValue() const { return defaultValue; }
75+
bool isOptional() const { return optional; }
7476

7577
private:
7678
/// The C++ type.

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230230
231231
)";
232232

233+
static const char *const inlineCreateBody = R"(
234+
::mlir::OperationState __state__({0}, getOperationName());
235+
build(builder, __state__{1});
236+
auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
237+
assert(__res__ && "builder didn't return the right type");
238+
return __res__;
239+
)";
240+
233241
//===----------------------------------------------------------------------===//
234242
// Utility structs and functions
235243
//===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665673
// Generates the build() method that takes each operand/attribute
666674
// as a stand-alone parameter.
667675
void genSeparateArgParamBuilder();
676+
void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
668677

669678
// Generates the build() method that takes each operand/attribute as a
670679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2568,6 +2577,39 @@ static bool canInferType(const Operator &op) {
25682577
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
25692578
}
25702579

2580+
void OpEmitter::genInlineCreateBody(
2581+
const SmallVector<MethodParameter> &paramList) {
2582+
SmallVector<MethodParameter> createParamList;
2583+
SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
2584+
createParamList.emplace_back("::mlir::OpBuilder &", "builder");
2585+
std::string locParamName = "location";
2586+
while (llvm::find_if(paramList, [&locParamName](const MethodParameter &p) {
2587+
return p.getName() == locParamName;
2588+
}) != paramList.end()) {
2589+
locParamName += "_";
2590+
}
2591+
createParamList.emplace_back("::mlir::Location", locParamName);
2592+
2593+
for (auto &param : paramList) {
2594+
if (param.getType() == "::mlir::OpBuilder &" ||
2595+
param.getType() == "::mlir::OperationState &")
2596+
continue;
2597+
createParamList.emplace_back(param.getType(), param.getName(),
2598+
param.getDefaultValue(), param.isOptional());
2599+
nonBuilderStateArgsList.push_back(param.getName());
2600+
}
2601+
auto *c = opClass.addStaticMethod(opClass.getClassName(), "create",
2602+
createParamList);
2603+
std::string nonBuilderStateArgs = "";
2604+
if (!nonBuilderStateArgsList.empty()) {
2605+
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2606+
interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
2607+
nonBuilderStateArgs = ", " + nonBuilderStateArgs;
2608+
}
2609+
c->body() << llvm::formatv(inlineCreateBody, locParamName,
2610+
nonBuilderStateArgs, opClass.getClassName());
2611+
}
2612+
25712613
void OpEmitter::genSeparateArgParamBuilder() {
25722614
SmallVector<AttrParamKind, 2> attrBuilderType;
25732615
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2584,10 +2626,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
25842626
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
25852627
attrType);
25862628

2587-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2629+
auto *m = opClass.addStaticMethod("void", "build", paramList);
25882630
// If the builder is redundant, skip generating the method.
25892631
if (!m)
25902632
return;
2633+
genInlineCreateBody(paramList);
2634+
25912635
auto &body = m->body();
25922636
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
25932637
/*isRawValueAttr=*/attrType ==
@@ -2712,10 +2756,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
27122756
if (op.getNumVariadicRegions())
27132757
paramList.emplace_back("unsigned", "numRegions");
27142758

2715-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2759+
auto *m = opClass.addStaticMethod("void", "build", paramList);
27162760
// If the builder is redundant, skip generating the method
27172761
if (!m)
27182762
return;
2763+
genInlineCreateBody(paramList);
27192764
auto &body = m->body();
27202765

27212766
// Operands
@@ -2826,10 +2871,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
28262871
if (op.getNumVariadicRegions())
28272872
paramList.emplace_back("unsigned", "numRegions");
28282873

2829-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2874+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28302875
// If the builder is redundant, skip generating the method
28312876
if (!m)
28322877
return;
2878+
genInlineCreateBody(paramList);
28332879
auto &body = m->body();
28342880

28352881
int numResults = op.getNumResults();
@@ -2906,10 +2952,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
29062952
buildParamList(paramList, inferredAttributes, resultNames,
29072953
TypeParamKind::None, attrType);
29082954

2909-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2955+
auto *m = opClass.addStaticMethod("void", "build", paramList);
29102956
// If the builder is redundant, skip generating the method
29112957
if (!m)
29122958
return;
2959+
genInlineCreateBody(paramList);
29132960
auto &body = m->body();
29142961
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
29152962
/*isRawValueAttr=*/attrType ==
@@ -2948,10 +2995,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
29482995
: "attributes";
29492996
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
29502997
attributesName, "{}");
2951-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2998+
auto *m = opClass.addStaticMethod("void", "build", paramList);
29522999
// If the builder is redundant, skip generating the method
29533000
if (!m)
29543001
return;
3002+
genInlineCreateBody(paramList);
29553003

29563004
auto &body = m->body();
29573005

@@ -3114,10 +3162,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
31143162
if (op.getNumVariadicRegions())
31153163
paramList.emplace_back("unsigned", "numRegions");
31163164

3117-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
3165+
auto *m = opClass.addStaticMethod("void", "build", paramList);
31183166
// If the builder is redundant, skip generating the method
31193167
if (!m)
31203168
return;
3169+
genInlineCreateBody(paramList);
31213170
auto &body = m->body();
31223171

31233172
// Operands

0 commit comments

Comments
 (0)