Skip to content

Commit 6a99423

Browse files
committed
[mlir] Expand prefixing to OpFormatGen
Follow up to also use the prefixed emitters in OpFormatGen (moved getGetterName(s) and getSetterName(s) to Operator as that is most convenient usage wise even though it just depends on Dialect). Prefix accessors in Test dialect and follow up on missed changes in OpDefinitionsGen. Differential Revision: https://reviews.llvm.org/D112118
1 parent 89950ad commit 6a99423

File tree

8 files changed

+214
-186
lines changed

8 files changed

+214
-186
lines changed

mlir/include/mlir/TableGen/Operator.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,17 @@ class Operator {
294294
// Returns the builders of this operation.
295295
ArrayRef<Builder> getBuilders() const { return builders; }
296296

297+
// Returns the preferred getter name for the accessor.
298+
std::string getGetterName(StringRef name) const {
299+
return getGetterNames(name).front();
300+
}
301+
302+
// Returns the getter names for the accessor.
303+
SmallVector<std::string, 2> getGetterNames(StringRef name) const;
304+
305+
// Returns the setter names for the accessor.
306+
SmallVector<std::string, 2> getSetterNames(StringRef name) const;
307+
297308
private:
298309
// Populates the vectors containing operands, attributes, results and traits.
299310
void populateOpStructure();

mlir/lib/TableGen/Operator.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/StringExtras.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/Debug.h"
24+
#include "llvm/Support/ErrorHandling.h"
2425
#include "llvm/Support/FormatVariadic.h"
2526
#include "llvm/TableGen/Error.h"
2627
#include "llvm/TableGen/Record.h"
@@ -642,3 +643,57 @@ auto Operator::getArgToOperandOrAttribute(int index) const
642643
-> OperandOrAttribute {
643644
return attrOrOperandMapping[index];
644645
}
646+
647+
// Helper to return the names for accessor.
648+
static SmallVector<std::string, 2>
649+
getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
650+
Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
651+
std::string prefix;
652+
if (prefixType != Dialect::EmitPrefix::Raw)
653+
prefix = isGetter ? "get" : "set";
654+
655+
SmallVector<std::string, 2> names;
656+
bool rawToo = prefixType == Dialect::EmitPrefix::Both;
657+
658+
auto skip = [&](StringRef newName) {
659+
bool shouldSkip = newName == "getOperands";
660+
if (!shouldSkip)
661+
return false;
662+
663+
// This note could be avoided where the final function generated would
664+
// have been identical. But preferably in the op definition avoiding using
665+
// the generic name and then getting a more specialize type is better.
666+
PrintNote(op.getLoc(),
667+
"Skipping generation of prefixed accessor `" + newName +
668+
"` as it overlaps with default one; generating raw form (`" +
669+
name + "`) still");
670+
return true;
671+
};
672+
673+
if (!prefix.empty()) {
674+
names.push_back(
675+
prefix + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true));
676+
// Skip cases which would overlap with default ones for now.
677+
if (skip(names.back())) {
678+
rawToo = true;
679+
names.clear();
680+
} else {
681+
LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName()
682+
<< "::" << names.back() << "\");\n"
683+
<< "WITH_GETTER(\"" << op.getQualCppClassName()
684+
<< "Adaptor::" << names.back() << "\");\n";);
685+
}
686+
}
687+
688+
if (prefix.empty() || rawToo)
689+
names.push_back(name.str());
690+
return names;
691+
}
692+
693+
SmallVector<std::string, 2> Operator::getGetterNames(StringRef name) const {
694+
return getGetterOrSetterNames(/*isGetter=*/true, *this, name);
695+
}
696+
697+
SmallVector<std::string, 2> Operator::getSetterNames(StringRef name) const {
698+
return getGetterOrSetterNames(/*isGetter=*/false, *this, name);
699+
}

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ TestDialect::getOperationPrinter(Operation *op) const {
339339
Optional<MutableOperandRange>
340340
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
341341
assert(index == 0 && "invalid successor index");
342-
return targetOperandsMutable();
342+
return getTargetOperandsMutable();
343343
}
344344

345345
//===----------------------------------------------------------------------===//
@@ -369,7 +369,7 @@ struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
369369

370370
LogicalResult matchAndRewrite(FoldToCallOp op,
371371
PatternRewriter &rewriter) const override {
372-
rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
372+
rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
373373
ValueRange());
374374
return success();
375375
}
@@ -597,8 +597,8 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
597597
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
598598
p << "test.isolated_region ";
599599
p.printOperand(op.getOperand());
600-
p.shadowRegionArgs(op.region(), op.getOperand());
601-
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
600+
p.shadowRegionArgs(op.getRegion(), op.getOperand());
601+
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
602602
}
603603

604604
//===----------------------------------------------------------------------===//
@@ -622,7 +622,7 @@ static ParseResult parseGraphRegionOp(OpAsmParser &parser,
622622

623623
static void print(OpAsmPrinter &p, GraphRegionOp op) {
624624
p << "test.graph_region ";
625-
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
625+
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
626626
}
627627

628628
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
@@ -642,7 +642,7 @@ static ParseResult parseAffineScopeOp(OpAsmParser &parser,
642642

643643
static void print(OpAsmPrinter &p, AffineScopeOp op) {
644644
p << "test.affine_scope ";
645-
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
645+
p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
646646
}
647647

648648
//===----------------------------------------------------------------------===//
@@ -678,7 +678,7 @@ static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
678678
}
679679

680680
static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
681-
p << " " << op.keyword();
681+
p << " " << op.getKeyword();
682682
}
683683

684684
//===----------------------------------------------------------------------===//
@@ -717,7 +717,7 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
717717

718718
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
719719
p << " wraps ";
720-
p.printGenericOp(&op.region().front().front());
720+
p.printGenericOp(&op.getRegion().front().front());
721721
}
722722

723723
//===----------------------------------------------------------------------===//
@@ -762,7 +762,7 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns(
762762
}
763763

764764
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
765-
return operand();
765+
return getOperand();
766766
}
767767

768768
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
@@ -971,15 +971,15 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
971971
// Note that we only need to print the "name" attribute if the asmprinter
972972
// result name disagrees with it. This can happen in strange cases, e.g.
973973
// when there are conflicts.
974-
bool namesDisagree = op.names().size() != op.getNumResults();
974+
bool namesDisagree = op.getNames().size() != op.getNumResults();
975975

976976
SmallString<32> resultNameStr;
977977
for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
978978
resultNameStr.clear();
979979
llvm::raw_svector_ostream tmpStream(resultNameStr);
980980
p.printOperand(op.getResult(i), tmpStream);
981981

982-
auto expectedName = op.names()[i].dyn_cast<StringAttr>();
982+
auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
983983
if (!expectedName ||
984984
tmpStream.str().drop_front() != expectedName.getValue()) {
985985
namesDisagree = true;
@@ -997,7 +997,7 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
997997
void StringAttrPrettyNameOp::getAsmResultNames(
998998
function_ref<void(Value, StringRef)> setNameFn) {
999999

1000-
auto value = names();
1000+
auto value = getNames();
10011001
for (size_t i = 0, e = value.size(); i != e; ++i)
10021002
if (auto str = value[i].dyn_cast<StringAttr>())
10031003
if (!str.getValue().empty())
@@ -1014,15 +1014,15 @@ static void print(OpAsmPrinter &p, RegionIfOp op) {
10141014
p << ": " << op.getOperandTypes();
10151015
p.printArrowTypeList(op.getResultTypes());
10161016
p << " then";
1017-
p.printRegion(op.thenRegion(),
1017+
p.printRegion(op.getThenRegion(),
10181018
/*printEntryBlockArgs=*/true,
10191019
/*printBlockTerminators=*/true);
10201020
p << " else";
1021-
p.printRegion(op.elseRegion(),
1021+
p.printRegion(op.getElseRegion(),
10221022
/*printEntryBlockArgs=*/true,
10231023
/*printBlockTerminators=*/true);
10241024
p << " join";
1025-
p.printRegion(op.joinRegion(),
1025+
p.printRegion(op.getJoinRegion(),
10261026
/*printEntryBlockArgs=*/true,
10271027
/*printBlockTerminators=*/true);
10281028
}
@@ -1064,15 +1064,15 @@ void RegionIfOp::getSuccessorRegions(
10641064
// We always branch to the join region.
10651065
if (index.hasValue()) {
10661066
if (index.getValue() < 2)
1067-
regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1067+
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
10681068
else
10691069
regions.push_back(RegionSuccessor(getResults()));
10701070
return;
10711071
}
10721072

10731073
// The then and else regions are the entry regions of this op.
1074-
regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1075-
regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1074+
regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1075+
regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
10761076
}
10771077

10781078
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ include "TestInterfaces.td"
2626
def Test_Dialect : Dialect {
2727
let name = "test";
2828
let cppNamespace = "::test";
29-
// Temporarily flipping to _Both (given this is test only/not intended for
30-
// general use, this won't be following the 2 week process here).
31-
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
29+
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
3230
let hasCanonicalizer = 1;
3331
let hasConstantMaterializer = 1;
3432
let hasOperationAttrVerify = 1;
@@ -305,9 +303,9 @@ def RankedIntElementsAttrOp : TEST_Op<"ranked_int_elements_attr"> {
305303
def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
306304
let results = (outs AnyTensor:$output);
307305
DerivedTypeAttr element_dtype =
308-
DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">;
306+
DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
309307
DerivedAttr size = DerivedAttr<"int",
310-
"return output().getType().cast<ShapedType>().getSizeInBits();",
308+
"return getOutput().getType().cast<ShapedType>().getSizeInBits();",
311309
"$_builder.getI32IntegerAttr($_self)">;
312310
}
313311

@@ -374,13 +372,10 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
374372

375373
def ConversionCallOp : TEST_Op<"conversion_call_op",
376374
[CallOpInterface]> {
377-
let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
375+
let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
378376
let results = (outs Variadic<AnyType>);
379377

380378
let extraClassDeclaration = [{
381-
/// Get the argument operands to the called function.
382-
operand_range getArgOperands() { return inputs(); }
383-
384379
/// Return the callee of this operation.
385380
::mlir::CallInterfaceCallable getCallableForCallee() {
386381
return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
@@ -394,7 +389,7 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
394389
let results = (outs FunctionType);
395390

396391
let extraClassDeclaration = [{
397-
::mlir::Region *getCallableRegion() { return &body(); }
392+
::mlir::Region *getCallableRegion() { return &getBody(); }
398393
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
399394
return getType().cast<::mlir::FunctionType>().getResults();
400395
}
@@ -673,7 +668,7 @@ def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
673668
let arguments = (ins AnyAttr:$attr);
674669

675670
let verifier = [{
676-
if (this->attr().hasTrait<AttributeTrait::TestAttrTrait>())
671+
if (this->getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
677672
return success();
678673
return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
679674
}];
@@ -2340,6 +2335,10 @@ def TestLinalgConvOp :
23402335
std::string getLibraryCallName() {
23412336
return "";
23422337
}
2338+
2339+
// To conform with interface requirement on operand naming.
2340+
mlir::ValueRange inputs() { return getInputs(); }
2341+
mlir::ValueRange outputs() { return getOutputs(); }
23432342
}];
23442343
}
23452344

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
3232
static void handleNoResultOp(PatternRewriter &rewriter,
3333
OpSymbolBindingNoResult op) {
3434
// Turn the no result op to a one-result op.
35-
rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
36-
op.operand());
35+
rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
36+
op.getOperand());
3737
}
3838

3939
static bool getFirstI32Result(Operation *op, Value &value) {
@@ -531,7 +531,7 @@ struct TestBoundedRecursiveRewrite
531531
PatternRewriter &rewriter) const final {
532532
// Decrement the depth of the op in-place.
533533
rewriter.updateRootInPlace(op, [&] {
534-
op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1));
534+
op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
535535
});
536536
return success();
537537
}
@@ -705,7 +705,7 @@ struct TestLegalizePatternDriver
705705

706706
// Mark the bound recursion operation as dynamically legal.
707707
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
708-
[](TestRecursiveRewriteOp op) { return op.depth() == 0; });
708+
[](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
709709

710710
// Handle a partial conversion.
711711
if (mode == ConversionMode::Partial) {
@@ -1026,9 +1026,9 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
10261026
LogicalResult
10271027
matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
10281028
ConversionPatternRewriter &rewriter) const final {
1029-
Block &firstBlock = op.body().front();
1029+
Block &firstBlock = op.getBody().front();
10301030
Operation *branchOp = firstBlock.getTerminator();
1031-
Block *secondBlock = &*(std::next(op.body().begin()));
1031+
Block *secondBlock = &*(std::next(op.getBody().begin()));
10321032
auto succOperands = branchOp->getOperands();
10331033
SmallVector<Value, 2> replacements(succOperands);
10341034
rewriter.eraseOp(branchOp);
@@ -1073,7 +1073,7 @@ struct TestMergeSingleBlockOps
10731073
op->getParentOfType<SingleBlockImplicitTerminatorOp>();
10741074
if (!parentOp)
10751075
return failure();
1076-
Block &innerBlock = op.region().front();
1076+
Block &innerBlock = op.getRegion().front();
10771077
TerminatorOp innerTerminator =
10781078
cast<TerminatorOp>(innerBlock.getTerminator());
10791079
rewriter.mergeBlockBefore(&innerBlock, op);
@@ -1104,7 +1104,7 @@ struct TestMergeBlocksPatternDriver
11041104
/// Expect the op to have a single block after legalization.
11051105
target.addDynamicallyLegalOp<TestMergeBlocksOp>(
11061106
[&](TestMergeBlocksOp op) -> bool {
1107-
return llvm::hasSingleElement(op.body());
1107+
return llvm::hasSingleElement(op.getBody());
11081108
});
11091109

11101110
/// Only allow `test.br` within test.merge_blocks op.

mlir/test/lib/Transforms/TestInlining.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct Inliner : public PassWrapper<Inliner, FunctionPass> {
5151
// Inline the functional region operation, but only clone the internal
5252
// region if there is more than one use.
5353
if (failed(inlineRegion(
54-
interface, &callee.body(), caller, caller.getArgOperands(),
54+
interface, &callee.getBody(), caller, caller.getArgOperands(),
5555
caller.getResults(), caller.getLoc(),
5656
/*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
5757
continue;

0 commit comments

Comments
 (0)