@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230
230
231
231
)" ;
232
232
233
+ static const char *const inlineCreateBody = R"(
234
+ ::mlir::OperationState __state__({0}, getOperationName());
235
+ build(builder, __state__{1});
236
+ auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
237
+ assert(__res__ && "builder didn't return the right type");
238
+ return __res__;
239
+ )" ;
240
+
233
241
// ===----------------------------------------------------------------------===//
234
242
// Utility structs and functions
235
243
// ===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665
673
// Generates the build() method that takes each operand/attribute
666
674
// as a stand-alone parameter.
667
675
void genSeparateArgParamBuilder ();
676
+ void genInlineCreateBody (const SmallVector<MethodParameter> ¶mList);
668
677
669
678
// Generates the build() method that takes each operand/attribute as a
670
679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2568,6 +2577,39 @@ static bool canInferType(const Operator &op) {
2568
2577
return op.getTrait (" ::mlir::InferTypeOpInterface::Trait" );
2569
2578
}
2570
2579
2580
+ void OpEmitter::genInlineCreateBody (
2581
+ const SmallVector<MethodParameter> ¶mList) {
2582
+ SmallVector<MethodParameter> createParamList;
2583
+ SmallVector<llvm::StringRef, 4 > nonBuilderStateArgsList;
2584
+ createParamList.emplace_back (" ::mlir::OpBuilder &" , " builder" );
2585
+ std::string locParamName = " location" ;
2586
+ while (llvm::find_if (paramList, [&locParamName](const MethodParameter &p) {
2587
+ return p.getName () == locParamName;
2588
+ }) != paramList.end ()) {
2589
+ locParamName += " _" ;
2590
+ }
2591
+ createParamList.emplace_back (" ::mlir::Location" , locParamName);
2592
+
2593
+ for (auto ¶m : paramList) {
2594
+ if (param.getType () == " ::mlir::OpBuilder &" ||
2595
+ param.getType () == " ::mlir::OperationState &" )
2596
+ continue ;
2597
+ createParamList.emplace_back (param.getType (), param.getName (),
2598
+ param.getDefaultValue (), param.isOptional ());
2599
+ nonBuilderStateArgsList.push_back (param.getName ());
2600
+ }
2601
+ auto *c = opClass.addStaticMethod (opClass.getClassName (), " create" ,
2602
+ createParamList);
2603
+ std::string nonBuilderStateArgs = " " ;
2604
+ if (!nonBuilderStateArgsList.empty ()) {
2605
+ llvm::raw_string_ostream nonBuilderStateArgsOS (nonBuilderStateArgs);
2606
+ interleaveComma (nonBuilderStateArgsList, nonBuilderStateArgsOS);
2607
+ nonBuilderStateArgs = " , " + nonBuilderStateArgs;
2608
+ }
2609
+ c->body () << llvm::formatv (inlineCreateBody, locParamName,
2610
+ nonBuilderStateArgs, opClass.getClassName ());
2611
+ }
2612
+
2571
2613
void OpEmitter::genSeparateArgParamBuilder () {
2572
2614
SmallVector<AttrParamKind, 2 > attrBuilderType;
2573
2615
attrBuilderType.push_back (AttrParamKind::WrappedAttr);
@@ -2584,10 +2626,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
2584
2626
buildParamList (paramList, inferredAttributes, resultNames, paramKind,
2585
2627
attrType);
2586
2628
2587
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2629
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2588
2630
// If the builder is redundant, skip generating the method.
2589
2631
if (!m)
2590
2632
return ;
2633
+ genInlineCreateBody (paramList);
2634
+
2591
2635
auto &body = m->body ();
2592
2636
genCodeForAddingArgAndRegionForBuilder (body, inferredAttributes,
2593
2637
/* isRawValueAttr=*/ attrType ==
@@ -2712,10 +2756,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
2712
2756
if (op.getNumVariadicRegions ())
2713
2757
paramList.emplace_back (" unsigned" , " numRegions" );
2714
2758
2715
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2759
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2716
2760
// If the builder is redundant, skip generating the method
2717
2761
if (!m)
2718
2762
return ;
2763
+ genInlineCreateBody (paramList);
2719
2764
auto &body = m->body ();
2720
2765
2721
2766
// Operands
@@ -2826,10 +2871,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
2826
2871
if (op.getNumVariadicRegions ())
2827
2872
paramList.emplace_back (" unsigned" , " numRegions" );
2828
2873
2829
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2874
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2830
2875
// If the builder is redundant, skip generating the method
2831
2876
if (!m)
2832
2877
return ;
2878
+ genInlineCreateBody (paramList);
2833
2879
auto &body = m->body ();
2834
2880
2835
2881
int numResults = op.getNumResults ();
@@ -2906,10 +2952,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2906
2952
buildParamList (paramList, inferredAttributes, resultNames,
2907
2953
TypeParamKind::None, attrType);
2908
2954
2909
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2955
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2910
2956
// If the builder is redundant, skip generating the method
2911
2957
if (!m)
2912
2958
return ;
2959
+ genInlineCreateBody (paramList);
2913
2960
auto &body = m->body ();
2914
2961
genCodeForAddingArgAndRegionForBuilder (body, inferredAttributes,
2915
2962
/* isRawValueAttr=*/ attrType ==
@@ -2948,10 +2995,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
2948
2995
: " attributes" ;
2949
2996
paramList.emplace_back (" ::llvm::ArrayRef<::mlir::NamedAttribute>" ,
2950
2997
attributesName, " {}" );
2951
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
2998
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
2952
2999
// If the builder is redundant, skip generating the method
2953
3000
if (!m)
2954
3001
return ;
3002
+ genInlineCreateBody (paramList);
2955
3003
2956
3004
auto &body = m->body ();
2957
3005
@@ -3114,10 +3162,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
3114
3162
if (op.getNumVariadicRegions ())
3115
3163
paramList.emplace_back (" unsigned" , " numRegions" );
3116
3164
3117
- auto *m = opClass.addStaticMethod (" void" , " build" , std::move ( paramList) );
3165
+ auto *m = opClass.addStaticMethod (" void" , " build" , paramList);
3118
3166
// If the builder is redundant, skip generating the method
3119
3167
if (!m)
3120
3168
return ;
3169
+ genInlineCreateBody (paramList);
3121
3170
auto &body = m->body ();
3122
3171
3123
3172
// Operands
0 commit comments