Skip to content

Commit 82140ad

Browse files
committed
[mlir] Add method to populate default attributes
Previously default attributes were only usable by way of the ODS generated accessors, but this was undesirable as 1. The ODS getters could construct Attribute each get request; 2. For non-C++ uses this would require either duplicating some of tee default attribute generating or generating additional bindings to generate methods; 3. Accessing op.getAttr("foo") and op.getFoo() would return different results; Generate method to populate default attributes that can be used to address these. This merely adds this facility but does not employ by default on any path. Differential Revision: https://reviews.llvm.org/D128962
1 parent 7ecec30 commit 82140ad

File tree

8 files changed

+105
-6
lines changed

8 files changed

+105
-6
lines changed

mlir/include/mlir/IR/ExtensibleDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ class DynamicOpDefinition {
431431
OperationName::PrintAssemblyFn printFn;
432432
OperationName::FoldHookFn foldHookFn;
433433
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
434+
OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
434435

435436
friend ExtensibleDialect;
436437
};

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ class OpState {
182182
static void getCanonicalizationPatterns(RewritePatternSet &results,
183183
MLIRContext *context) {}
184184

185+
/// This hook populates any unset default attrs.
186+
static void populateDefaultAttrs(const RegisteredOperationName &,
187+
NamedAttrList &) {}
188+
185189
protected:
186190
/// If the concrete type didn't implement a custom verifier hook, just fall
187191
/// back to this one which accepts everything.
@@ -1869,6 +1873,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
18691873
OpState::printOpName(op, p, defaultDialect);
18701874
return cast<ConcreteType>(op).print(p);
18711875
}
1876+
/// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
1877+
static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
1878+
return ConcreteType::populateDefaultAttrs;
1879+
}
18721880
/// Implementation of `VerifyInvariantsFn` OperationName hook.
18731881
static LogicalResult verifyInvariants(Operation *op) {
18741882
static_assert(hasNoDataMembers(),

mlir/include/mlir/IR/Operation.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,15 @@ class alignas(8) Operation final
467467
setAttrs(attrs.getDictionary(getContext()));
468468
}
469469

470+
/// Sets default attributes on unset attributes.
471+
void populateDefaultAttrs() {
472+
if (auto registered = getRegisteredInfo()) {
473+
NamedAttrList attrs(getAttrDictionary());
474+
registered->populateDefaultAttrs(attrs);
475+
setAttrs(attrs.getDictionary(getContext()));
476+
}
477+
}
478+
470479
//===--------------------------------------------------------------------===//
471480
// Blocks
472481
//===--------------------------------------------------------------------===//

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Dialect;
3636
class DictionaryAttr;
3737
class ElementsAttr;
3838
class MutableOperandRangeRange;
39+
class NamedAttrList;
3940
class Operation;
4041
struct OperationState;
4142
class OpAsmParser;
@@ -69,6 +70,10 @@ class OperationName {
6970
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
7071
using ParseAssemblyFn =
7172
llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
73+
// Note: RegisteredOperationName is passed as reference here as the derived
74+
// class is defined below.
75+
using PopulateDefaultAttrsFn = llvm::unique_function<void(
76+
const RegisteredOperationName &, NamedAttrList &) const>;
7277
using PrintAssemblyFn =
7378
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
7479
using VerifyInvariantsFn =
@@ -112,6 +117,7 @@ class OperationName {
112117
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
113118
HasTraitFn hasTraitFn;
114119
ParseAssemblyFn parseAssemblyFn;
120+
PopulateDefaultAttrsFn populateDefaultAttrsFn;
115121
PrintAssemblyFn printAssemblyFn;
116122
VerifyInvariantsFn verifyInvariantsFn;
117123
VerifyRegionInvariantsFn verifyRegionInvariantsFn;
@@ -254,7 +260,8 @@ class RegisteredOperationName : public OperationName {
254260
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
255261
T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
256262
T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
257-
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
263+
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
264+
T::getPopulateDefaultAttrsFn());
258265
}
259266
/// The use of this method is in general discouraged in favor of
260267
/// 'insert<CustomOp>(dialect)'.
@@ -266,7 +273,8 @@ class RegisteredOperationName : public OperationName {
266273
FoldHookFn &&foldHook,
267274
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
268275
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
269-
ArrayRef<StringRef> attrNames);
276+
ArrayRef<StringRef> attrNames,
277+
PopulateDefaultAttrsFn &&populateDefaultAttrs);
270278

271279
/// Return the dialect this operation is registered to.
272280
Dialect &getDialect() const { return *impl->dialect; }
@@ -364,6 +372,10 @@ class RegisteredOperationName : public OperationName {
364372
return impl->attributeNames;
365373
}
366374

375+
/// This hook implements the method to populate defaults attributes that are
376+
/// unset.
377+
void populateDefaultAttrs(NamedAttrList &attrs) const;
378+
367379
/// Represent the operation name as an opaque pointer. (Used to support
368380
/// PointerLikeTypeTraits).
369381
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {

mlir/lib/IR/ExtensibleDialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,8 @@ void ExtensibleDialect::registerDynamicOp(
447447
std::move(op->printFn), std::move(op->verifyFn),
448448
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
449449
std::move(op->getCanonicalizationPatternsFn),
450-
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
450+
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
451+
std::move(op->getPopulateDefaultAttrsFn));
451452
}
452453

453454
bool ExtensibleDialect::classof(const Dialect *dialect) {

mlir/lib/IR/MLIRContext.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,14 +707,19 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
707707
return impl->parseAssemblyFn(parser, result);
708708
}
709709

710+
void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
711+
impl->populateDefaultAttrsFn(*this, attrs);
712+
}
713+
710714
void RegisteredOperationName::insert(
711715
StringRef name, Dialect &dialect, TypeID typeID,
712716
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
713717
VerifyInvariantsFn &&verifyInvariants,
714718
VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
715719
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
716720
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
717-
ArrayRef<StringRef> attrNames) {
721+
ArrayRef<StringRef> attrNames,
722+
PopulateDefaultAttrsFn &&populateDefaultAttrs) {
718723
MLIRContext *ctx = dialect.getContext();
719724
auto &ctxImpl = ctx->getImpl();
720725
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
@@ -769,6 +774,7 @@ void RegisteredOperationName::insert(
769774
impl.verifyInvariantsFn = std::move(verifyInvariants);
770775
impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
771776
impl.attributeNames = cachedAttrNames;
777+
impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
772778
}
773779

774780
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ class OpEmitter {
430430
// Generates getters for named successors.
431431
void genNamedSuccessorGetters();
432432

433+
// Generates the method to populate default attributes.
434+
void genPopulateDefaultAttributes();
435+
433436
// Generates builder methods for the operation.
434437
void genBuilder();
435438

@@ -823,6 +826,7 @@ OpEmitter::OpEmitter(const Operator &op,
823826
genAttrSetters();
824827
genOptionalAttrRemovers();
825828
genBuilder();
829+
genPopulateDefaultAttributes();
826830
genParser();
827831
genPrinter();
828832
genVerifier();
@@ -1587,6 +1591,45 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
15871591
<< llvm::join(resultTypes, ", ") << "});\n\n";
15881592
}
15891593

1594+
void OpEmitter::genPopulateDefaultAttributes() {
1595+
// All done if no attributes have default values.
1596+
if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
1597+
return !named.attr.hasDefaultValue();
1598+
}))
1599+
return;
1600+
1601+
SmallVector<MethodParameter> paramList;
1602+
paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
1603+
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
1604+
auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
1605+
ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
1606+
auto &body = m->body();
1607+
body.indent();
1608+
1609+
// Set default attributes that are unset.
1610+
body << "auto attrNames = opName.getAttributeNames();\n";
1611+
body << "::mlir::Builder " << odsBuilder
1612+
<< "(attrNames.front().getContext());\n";
1613+
StringMap<int> attrIndex;
1614+
for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
1615+
attrIndex[it.value().first] = it.index();
1616+
}
1617+
for (const NamedAttribute &namedAttr : op.getAttributes()) {
1618+
auto &attr = namedAttr.attr;
1619+
if (!attr.hasDefaultValue())
1620+
continue;
1621+
auto index = attrIndex[namedAttr.name];
1622+
body << "if (!attributes.get(attrNames[" << index << "])) {\n";
1623+
FmtContext fctx;
1624+
fctx.withBuilder(odsBuilder);
1625+
std::string defaultValue = std::string(
1626+
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1627+
body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
1628+
index, defaultValue);
1629+
body.unindent() << "}\n";
1630+
}
1631+
}
1632+
15901633
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
15911634
SmallVector<MethodParameter> paramList;
15921635
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
@@ -1869,7 +1912,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
18691912
auto numResults = op.getNumResults();
18701913
resultTypeNames.reserve(numResults);
18711914

1872-
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1915+
paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
18731916
paramList.emplace_back("::mlir::OperationState &", builderOpState);
18741917

18751918
switch (typeParamKind) {
@@ -2879,7 +2922,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
28792922
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
28802923
body << " if (!attr)\n attr = " << defaultValue << ";\n";
28812924
}
2882-
body << " return attr;\n";
2925+
body << "return attr;\n";
28832926
};
28842927

28852928
{

mlir/unittests/IR/OperationSupportTest.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/IR/OperationSupport.h"
10+
#include "../../test/lib/Dialect/Test/TestDialect.h"
1011
#include "mlir/IR/Builders.h"
1112
#include "mlir/IR/BuiltinTypes.h"
1213
#include "llvm/ADT/BitVector.h"
@@ -271,4 +272,22 @@ TEST(NamedAttrListTest, TestAppendAssign) {
271272
attrs.assign({});
272273
ASSERT_TRUE(attrs.empty());
273274
}
275+
276+
TEST(OperandStorageTest, PopulateDefaultAttrs) {
277+
MLIRContext context;
278+
context.getOrLoadDialect<test::TestDialect>();
279+
Builder builder(&context);
280+
281+
OpBuilder b(&context);
282+
auto req1 = b.getI32IntegerAttr(10);
283+
auto req2 = b.getI32IntegerAttr(60);
284+
Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
285+
nullptr, req2);
286+
EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr);
287+
op->populateDefaultAttrs();
288+
auto opt = op->getAttr("default_valued_attr");
289+
EXPECT_NE(opt, nullptr) << *op;
290+
291+
op->destroy();
292+
}
274293
} // namespace

0 commit comments

Comments
 (0)