Skip to content

[MLIR][DLTI] Make DLTI queries visit all ancestors/respect nested scopes #115043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/docs/Dialects/Transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,10 @@ ops rather than having the methods directly act on the payload IR.

[include "Dialects/DebugExtensionOps.md"]

## DLTI Transform Operations

[include "Dialects/DLTITransformOps.md"]

## IRDL (extension) Transform Operations

[include "Dialects/IRDLExtensionOps.md"]
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class DataLayoutEntryAttrStorage;
} // namespace mlir
namespace mlir {
namespace dlti {
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
/// query interface-implementing attrs, starting from attr obtained from `op`.
/// Perform a DLTI-query at `op`, by recursively querying each key of `keys` on
/// `DLTIQueryInterface`-implementing attributes of an op, attempting this query
/// procedure for all ancestors of `op` in turn, starting with `op` itself.
FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
bool emitError = false);
} // namespace dlti
Expand Down
33 changes: 24 additions & 9 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,40 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect.

A lookup is performed for the given `keys` at `target` op - or its closest
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
returns an attribute for a key. Each key should be either a (quoted) string
or a type. If more than one key is provided, the lookup continues
recursively, now on the returned attributes, with the condition that these
implement the above interface. For example if the payload IR is
A lookup is performed with respect to `keys`, first at the `target` op and
subsequently at its ancestors. At each op `DLTIQueryInterface`-implementing
attributes are recursively queried, one key per query. Each key should be
either a (quoted) string or a type. If more than one key is provided, the
lookup continues recursively, now on the returned attributes, with the
condition that these implement the above interface.

For example if the payload IR is

```
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
#dlti.map<#dlti.dl_entry<"B", 42: int>>>} {
module attributes {#dlti.map = #dlti.map<"A" = #dlti.map<"B" = 42>>} {
func.func private @f()
}
```
and we have that `%func` is a Tranform handle to op `@f`, then
and we have that `%func` is a Tranform handle to func `@f`, then
`transform.dlti.query ["A", "B"] at %func` returns 42 as a param and
`transform.dlti.query ["A"] at %func` returns the `#dlti.map` attribute
containing just the key "B" and its value. Using `["B"]` or `["A","C"]` as
`keys` will yield an error.

In the below example we have that querying `["A", "B"]` at `%func` - or any
op that it contains - returns 0, while these same keys yield 42 if the
containing module would be the `target` of the query. ```
```
module attributes {#dlti.map = #dlti.map<"A" = #dlti.map<"B" = 42,
"C" = 1>>} {
func.func private @f() attributes {#dlti.map = #dlti.map<"A" =
#dlti.map<"B" = 0>>} {
...
}
}
```
Querying `["A","C"]` on the module, the func, or any contained op yields 1.

#### Return modes

When successful, the result, `associated_attr`, associates one attribute as
Expand Down
120 changes: 58 additions & 62 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;

Expand Down Expand Up @@ -489,77 +485,77 @@ void TargetSystemSpecAttr::print(AsmPrinter &printer) const {
// DLTIDialect
//===----------------------------------------------------------------------===//

/// Retrieve the first `DLTIQueryInterface`-implementing attribute that is
/// attached to `op` or such an attr on as close as possible an ancestor. The
/// op the attribute is attached to is returned as well.
static std::pair<DLTIQueryInterface, Operation *>
getClosestQueryable(Operation *op) {
DLTIQueryInterface queryable = {};

// Search op and its ancestors for the first attached DLTIQueryInterface attr.
do {
for (NamedAttribute attr : op->getAttrs())
if ((queryable = dyn_cast<DLTIQueryInterface>(attr.getValue())))
break;
} while (!queryable && (op = op->getParentOp()));

return std::pair(queryable, op);
}

FailureOr<Attribute>
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
InFlightDiagnostic diag = op->emitError() << "target op of failed DLTI query";

if (keys.empty()) {
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
if (emitError)
diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
}
else
diag.abandon();
return failure();
}

auto [queryable, queryOp] = getClosestQueryable(op);
Operation *reportOp = (queryOp ? queryOp : op);

if (!queryable) {
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "no DLTI-queryable attrs on target op or any of its ancestors";
}
return failure();
}

Attribute currentAttr = queryable;
for (auto &&[idx, key] : llvm::enumerate(keys)) {
if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
auto maybeAttr = map.query(key);
if (failed(maybeAttr)) {
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "key " << keyToStr(key)
<< " has no DLTI-mapping per attr: " << map;
auto interleaveComma = [](ArrayRef<DataLayoutEntryKey> keys) {
std::string buf;
llvm::interleave(
keys, [&](auto key) { buf += keyToStr(key); }, [&]() { buf += ","; });
return buf;
};

// Recursively replace `currentAttr` by the attribute obtained by querying a
// new key on each new `currentAttr` until all `keys` have been exhausted -
// `atOp` is only used for error reporting.
auto queryKeysOnAttribute = [&](Attribute currentAttr,
Operation *atOp) -> FailureOr<Attribute> {
for (auto &&[idx, key] : llvm::enumerate(keys)) {
if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
auto maybeAttr = map.query(key);
if (failed(maybeAttr)) {
if (emitError)
diag.attachNote(atOp->getLoc())
<< "key not present - failed at keys: ["
<< interleaveComma(keys.take_front(idx + 1)) << "]";
return failure();
}
currentAttr = *maybeAttr;
} else {
// The previous key, if any, is responsible for the current currentAttr.
if (idx > 0 && emitError)
diag.attachNote(atOp->getLoc())
<< "attribute at keys [" << interleaveComma(keys.take_front(idx))
<< "] is not queryable";
return failure();
}
currentAttr = *maybeAttr;
} else {
if (emitError) {
std::string commaSeparatedKeys;
llvm::interleave(
keys.take_front(idx), // All prior keys.
[&](auto key) { commaSeparatedKeys += keyToStr(key); },
[&]() { commaSeparatedKeys += ","; });

auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "got non-DLTI-queryable attribute upon looking up keys ["
<< commaSeparatedKeys << "] at op";
}
return failure();
}
return currentAttr;
};

// Run over all ancestors of `op`, starting the recursive attribute query for
// each ancestor which has an attribute on which we can perform a query.
for (Operation *ancestor = op; ancestor; ancestor = ancestor->getParentOp()) {
DLTIQueryInterface queryableAttr;
// NB: only the op's first DLTI attr will be inspected
for (NamedAttribute attr : ancestor->getAttrs())
if (auto queryableAttr = dyn_cast<DLTIQueryInterface>(attr.getValue())) {
auto maybeAttr = queryKeysOnAttribute(queryableAttr, ancestor);
if (succeeded(maybeAttr)) {
diag.abandon();
return maybeAttr;
}
}
}

if (emitError) {
if (diag.getUnderlyingDiagnostic()->getNotes().empty())
diag.attachNote(op->getLoc())
<< "no DLTI-queryable attrs on target op or any of its ancestors";
} else {
diag.abandon();
}

return currentAttr;
return failure();
}

constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
Expand Down
Loading
Loading