@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230
230
231
231
)" ;
232
232
233
+ static const char *const inlineCreateBody = R"(
234
+ ::mlir::OperationState __state__({0}, getOperationName());
235
+ build(builder, __state__{1});
236
+ auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
237
+ assert(__res__ && "builder didn't return the right type");
238
+ return __res__;
239
+ )" ;
240
+
233
241
// ===----------------------------------------------------------------------===//
234
242
// Utility structs and functions
235
243
// ===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665
673
// Generates the build() method that takes each operand/attribute
666
674
// as a stand-alone parameter.
667
675
void genSeparateArgParamBuilder ();
676
+ void genInlineCreateBody (const SmallVector<MethodParameter> ¶mList);
668
677
669
678
// Generates the build() method that takes each operand/attribute as a
670
679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2566,40 @@ static bool canInferType(const Operator &op) {
2557
2566
return op.getTrait (" ::mlir::InferTypeOpInterface::Trait" );
2558
2567
}
2559
2568
2569
+ void OpEmitter::genInlineCreateBody (
2570
+ const SmallVector<MethodParameter> ¶mList) {
2571
+ SmallVector<MethodParameter> createParamList;
2572
+ SmallVector<llvm::StringRef, 4 > nonBuilderStateArgsList;
2573
+ createParamList.emplace_back (" ::mlir::OpBuilder &" , " builder" );
2574
+ std::string locParamName = " location" ;
2575
+ const MethodParameter *locCollision =
2576
+ llvm::find_if (paramList, [&locParamName](const MethodParameter &p) {
2577
+ return p.getName () == locParamName;
2578
+ });
2579
+ while (locCollision && locCollision->getName () == locParamName)
2580
+ locParamName += " _" ;
2581
+ createParamList.emplace_back (" ::mlir::Location" , locParamName);
2582
+
2583
+ for (auto ¶m : paramList) {
2584
+ if (param.getType () == " ::mlir::OpBuilder &" ||
2585
+ param.getType () == " ::mlir::OperationState &" )
2586
+ continue ;
2587
+ createParamList.emplace_back (param.getType (), param.getName (),
2588
+ param.getDefaultValue (), param.isOptional ());
2589
+ nonBuilderStateArgsList.push_back (param.getName ());
2590
+ }
2591
+ auto *c = opClass.addStaticMethod (opClass.getClassName (), " create" ,
2592
+ createParamList);
2593
+ std::string nonBuilderStateArgs = " " ;
2594
+ if (!nonBuilderStateArgsList.empty ()) {
2595
+ llvm::raw_string_ostream nonBuilderStateArgsOS (nonBuilderStateArgs);
2596
+ interleaveComma (nonBuilderStateArgsList, nonBuilderStateArgsOS);
2597
+ nonBuilderStateArgs = " , " + nonBuilderStateArgs;
2598
+ }
2599
+ c->body () << llvm::formatv (inlineCreateBody, locParamName,
2600
+ nonBuilderStateArgs, opClass.getClassName ());
2601
+ }
2602
+
2560
2603
void OpEmitter::genSeparateArgParamBuilder () {
2561
2604
SmallVector<AttrParamKind, 2 > attrBuilderType;
2562
2605
attrBuilderType.push_back (AttrParamKind::WrappedAttr);
@@ -2573,10 +2616,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
2573
2616
buildParamList (paramList, inferredAttributes, resultNames, paramKind,
2574
2617
attrType);
2575
2618
2576
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2619
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2577
2620
// If the builder is redundant, skip generating the method.
2578
2621
if (!m)
2579
2622
return ;
2623
+ genInlineCreateBody (paramList);
2624
+
2580
2625
auto &body = m->body ();
2581
2626
genCodeForAddingArgAndRegionForBuilder (body, inferredAttributes,
2582
2627
/* isRawValueAttr=*/ attrType ==
@@ -2701,10 +2746,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
2701
2746
if (op.getNumVariadicRegions ())
2702
2747
paramList.emplace_back (" unsigned" , " numRegions" );
2703
2748
2704
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2749
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2705
2750
// If the builder is redundant, skip generating the method
2706
2751
if (!m)
2707
2752
return ;
2753
+ genInlineCreateBody (paramList);
2708
2754
auto &body = m->body ();
2709
2755
2710
2756
// Operands
@@ -2815,10 +2861,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
2815
2861
if (op.getNumVariadicRegions ())
2816
2862
paramList.emplace_back (" unsigned" , " numRegions" );
2817
2863
2818
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2864
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2819
2865
// If the builder is redundant, skip generating the method
2820
2866
if (!m)
2821
2867
return ;
2868
+ genInlineCreateBody (paramList);
2822
2869
auto &body = m->body ();
2823
2870
2824
2871
int numResults = op.getNumResults ();
@@ -2895,10 +2942,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2895
2942
buildParamList (paramList, inferredAttributes, resultNames,
2896
2943
TypeParamKind::None, attrType);
2897
2944
2898
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2945
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2899
2946
// If the builder is redundant, skip generating the method
2900
2947
if (!m)
2901
2948
return ;
2949
+ genInlineCreateBody (paramList);
2902
2950
auto &body = m->body ();
2903
2951
genCodeForAddingArgAndRegionForBuilder (body, inferredAttributes,
2904
2952
/* isRawValueAttr=*/ attrType ==
@@ -2937,10 +2985,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
2937
2985
: " attributes" ;
2938
2986
paramList.emplace_back (" ::llvm::ArrayRef<::mlir::NamedAttribute>" ,
2939
2987
attributesName, " {}" );
2940
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2988
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2941
2989
// If the builder is redundant, skip generating the method
2942
2990
if (!m)
2943
2991
return ;
2992
+ genInlineCreateBody (paramList);
2944
2993
2945
2994
auto &body = m->body ();
2946
2995
@@ -3103,10 +3152,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
3103
3152
if (op.getNumVariadicRegions ())
3104
3153
paramList.emplace_back (" unsigned" , " numRegions" );
3105
3154
3106
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
3155
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
3107
3156
// If the builder is redundant, skip generating the method
3108
3157
if (!m)
3109
3158
return ;
3159
+ genInlineCreateBody (paramList);
3110
3160
auto &body = m->body ();
3111
3161
3112
3162
// Operands
0 commit comments