Skip to content

Commit 0eb403a

Browse files
committed
[mlir][transform] Introduce transform.sequence op
Sequence is an important transform combination primitive that just indicates transform ops being applied in a row. The simplest version requires fails immediately if any transformation in the sequence fails. Introducing this operation allows one to start placing transform IR within other IR. Depends On D123135 Reviewed By: Mogball, rriddle Differential Revision: https://reviews.llvm.org/D123664
1 parent e37726b commit 0eb403a

File tree

15 files changed

+525
-31
lines changed

15 files changed

+525
-31
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
# The dialect does not have its own ops, so just generate the dialect files.
1+
# Generate the dialect files from the dialect .td.
2+
#
3+
# TODO: Make it possible to use XDialect instead of XOpsDialect in
4+
# add_mlir_dialect.
25
set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
36
mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
47
mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
58
add_public_tablegen_target(MLIRTransformDialectIncGen)
69
add_dependencies(mlir-headers MLIRTransformDialectIncGen)
710

11+
add_mlir_dialect(TransformOps transform)
12+
813
add_mlir_interface(TransformInterfaces)

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def Transform_Dialect : Dialect {
161161

162162
let name = "transform";
163163
let cppNamespace = "::mlir::transform";
164+
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
164165

165166
let extraClassDeclaration = [{
166167
// Make addOperations available to the TransformDialectExtension class.
@@ -172,4 +173,9 @@ def Transform_Dialect : Dialect {
172173
}];
173174
}
174175

176+
// Base class for ops that belong to the tranfsorm dialect. Ops defined in
177+
// extensions of this dialect may also use this.
178+
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
179+
: Op<Transform_Dialect, mnemonic, traits>;
180+
175181
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ class TransformOpInterface;
3333
/// expected to populate the `TransformResults` class instance in order to
3434
/// update the mapping. The `applyTransform` method takes care of propagating
3535
/// the state of `TransformResults` into the instance of this class.
36+
///
37+
/// When applying transform IR operations with regions, the client is expected
38+
/// to create a RegionScope RAII object to create a new "stack frame" for
39+
/// values defined inside the region. The mappings from and to these values will
40+
/// be automatically dropped when the object goes out of scope, typically at the
41+
/// end of the "apply" function of the parent operation. If a region contains
42+
/// blocks with arguments, the client can map those arguments to payload IR ops
43+
/// using "mapBlockArguments".
3644
class TransformState {
3745
/// Mapping between a Value in the transform IR and the corresponding set of
3846
/// operations in the payload IR.
@@ -42,9 +50,19 @@ class TransformState {
4250
/// currently associated with.
4351
using TransformOpReverseMapping = DenseMap<Operation *, Value>;
4452

53+
/// Bidirectional mappings between transform IR values and payload IR
54+
/// operations.
55+
struct Mappings {
56+
TransformOpMapping direct;
57+
TransformOpReverseMapping reverse;
58+
};
59+
4560
public:
46-
/// Creates a state for the transformation rooted at the given op.
47-
explicit TransformState(Operation *root);
61+
/// Creates a state for transform ops living in the given region. The parent
62+
/// operation of the region. The second argument points to the root operation
63+
/// in the payload IR beind transformed, which may or may not contain the
64+
/// region with transform ops.
65+
TransformState(Region &region, Operation *root);
4866

4967
/// Returns the op at which the transformation state is rooted. This is
5068
/// typically helpful for transformations that apply globally.
@@ -58,10 +76,96 @@ class TransformState {
5876
/// the state accordingly.
5977
LogicalResult applyTransform(TransformOpInterface transform);
6078

79+
/// Records the mapping between a block argument in the transform IR and a
80+
/// list of operations in the payload IR. The arguments must be defined in
81+
/// blocks of the currently processed transform IR region, typically after a
82+
/// region scope is defined.
83+
LogicalResult mapBlockArguments(BlockArgument argument,
84+
ArrayRef<Operation *> operations) {
85+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
86+
assert(argument.getParentRegion() == regionStack.back() &&
87+
"mapping block arguments from a region other than the active one");
88+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
89+
return setPayloadOps(argument, operations);
90+
}
91+
92+
// Forward declarations to support limited visibility.
93+
class RegionScope;
94+
95+
/// Creates a new region scope for the given region. The region is expected to
96+
/// be nested in the currently processed region.
97+
// Implementation note: this method is inline but implemented outside of the
98+
// class body to comply with visibility and full-declaration requirements.
99+
inline RegionScope make_region_scope(Region &region);
100+
101+
/// A RAII object maintaining a "stack frame" for a transform IR region. When
102+
/// applying a transform IR operation that contains a region, the caller is
103+
/// expected to create a RegionScope before applying the ops contained in the
104+
/// region. This ensures that the mappings between values defined in the
105+
/// transform IR region and payload IR operations are cleared when the region
106+
/// processing ends; such values cannot be accessed outside the region.
107+
class RegionScope {
108+
public:
109+
/// Forgets the mapping from or to values defined in the associated
110+
/// transform IR region.
111+
~RegionScope() {
112+
state.mappings.erase(region);
113+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
114+
state.regionStack.pop_back();
115+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
116+
}
117+
118+
private:
119+
/// Creates a new scope for mappings between values defined in the given
120+
/// transform IR region and payload IR operations.
121+
RegionScope(TransformState &state, Region &region)
122+
: state(state), region(&region) {
123+
auto res = state.mappings.try_emplace(this->region);
124+
assert(res.second && "the region scope is already present");
125+
(void)res;
126+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
127+
assert(state.regionStack.back()->isProperAncestor(&region) &&
128+
"scope started at a non-nested region");
129+
state.regionStack.push_back(&region);
130+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
131+
}
132+
133+
/// Back-reference to the transform state.
134+
TransformState &state;
135+
136+
/// The region this scope is associated with.
137+
Region *region;
138+
139+
friend RegionScope TransformState::make_region_scope(Region &);
140+
};
141+
friend class RegionScope;
142+
61143
private:
62144
/// Identifier for storing top-level value in the `operations` mapping.
63145
static constexpr Value kTopLevelValue = Value();
64146

147+
/// Returns the mappings frame for the reigon in which the value is defined.
148+
const Mappings &getMapping(Value value) const {
149+
return const_cast<TransformState *>(this)->getMapping(value);
150+
}
151+
Mappings &getMapping(Value value) {
152+
auto it = mappings.find(value.getParentRegion());
153+
assert(it != mappings.end() &&
154+
"trying to find a mapping for a value from an unmapped region");
155+
return it->second;
156+
}
157+
158+
/// Returns the mappings frame for the region in which the operation resides.
159+
const Mappings &getMapping(Operation *operation) const {
160+
return const_cast<TransformState *>(this)->getMapping(operation);
161+
}
162+
Mappings &getMapping(Operation *operation) {
163+
auto it = mappings.find(operation->getParentRegion());
164+
assert(it != mappings.end() &&
165+
"trying to find a mapping for an operation from an unmapped region");
166+
return it->second;
167+
}
168+
65169
/// Sets the payload IR ops associated with the given transform IR value.
66170
/// Fails if this would result in multiple transform IR values with uses
67171
/// corresponding to the same payload IR ops. For example, a hypothetical
@@ -88,9 +192,19 @@ class TransformState {
88192
void updatePayloadOps(Value value,
89193
function_ref<Operation *(Operation *)> callback);
90194

91-
/// The mapping between payload IR values and transform IR ops.
92-
TransformOpMapping operationMapping;
93-
TransformOpReverseMapping reverseMapping;
195+
/// The mappings between transform IR values and payload IR ops, aggregated by
196+
/// the region in which the transform IR values are defined.
197+
llvm::SmallDenseMap<Region *, Mappings> mappings;
198+
199+
/// The top-level operation that contains all payload IR, typically a module.
200+
Operation *topLevel;
201+
202+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
203+
/// A stack of nested regions that are being processed in the transform IR.
204+
/// Each region must be an ancestor of the following regions in this list.
205+
/// These are also the keys for "mappings".
206+
SmallVector<Region *> regionStack;
207+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
94208
};
95209

96210
/// Local mapping between values defined by a specific op implementing the
@@ -123,6 +237,10 @@ class TransformResults {
123237
SmallVector<Operation *> operations;
124238
};
125239

240+
TransformState::RegionScope TransformState::make_region_scope(Region &region) {
241+
return RegionScope(*this, region);
242+
}
243+
126244
} // namespace transform
127245
} // namespace mlir
128246

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- TransformDialect.h - Transform dialect operations --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
10+
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
13+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14+
#include "mlir/IR/OpDefinition.h"
15+
#include "mlir/IR/OpImplementation.h"
16+
17+
#define GET_OP_CLASSES
18+
#include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
19+
20+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
10+
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
11+
12+
include "mlir/IR/OpAsmInterface.td"
13+
include "mlir/Dialect/PDL/IR/PDLTypes.td"
14+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
15+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
16+
17+
def SequenceOp : TransformDialectOp<"sequence",
18+
[DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
19+
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
20+
let summary = "Contains a sequence of other transform ops to apply";
21+
let description = [{
22+
The transformations indicated by the sequence are applied in order of their
23+
appearance. Each value produced by a transformation within the sequence
24+
corresponds to an operation or a group of operations in the payload IR.
25+
Each value may be used at most once by another transformation operation as
26+
the transformation is likely to replace the transformed operation with
27+
another operation or a group thereof. In such cases, the transformation
28+
operation is expected to produce a new value to denote the newly produced
29+
operations that can be transformed further. During application, if any
30+
transformation in the sequence fails, the entire sequence fails immediately
31+
leaving the payload IR in potentially invalid state, i.e., this operation
32+
offers no transformation rollback capabilities.
33+
34+
The entry block of this operation has a single argument that maps to either
35+
the operand if provided or the top-level container operation of the payload
36+
IR, typically the root operation of the pass interpreting the transform
37+
dialect. Operand omission is only allowed for sequences not contained in
38+
another sequence.
39+
}];
40+
41+
let arguments = (ins Optional<PDL_Operation>:$root);
42+
let results = (outs Variadic<AnyType>:$results);
43+
let regions = (region SizedRegion<1>:$body);
44+
45+
let assemblyFormat =
46+
"($root^)? attr-dict-with-keyword regions (`:` type($results)^)?";
47+
48+
let extraClassDeclaration = [{
49+
/// Allow the dialect prefix to be omitted.
50+
static StringRef getDefaultDialect() { return "transform"; }
51+
52+
Block *getBodyBlock() {
53+
return &getBody().front();
54+
}
55+
}];
56+
57+
let hasVerifier = 1;
58+
}
59+
60+
def YieldOp : TransformDialectOp<"yield", [Terminator]> {
61+
let summary = "Yields operation handles from a transform IR region";
62+
let description = [{
63+
This terminator operation yields operation handles from regions of the
64+
transform IR ops back to the containing op. It is not itself associated with
65+
any transformation on the payload IR and is used for flow purposes only.
66+
}];
67+
68+
let arguments = (ins Variadic<AnyType>:$operands);
69+
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
70+
71+
let builders = [
72+
OpBuilder<(ins), [{
73+
return build($_builder, $_state, ::mlir::ValueRange());
74+
}]>
75+
];
76+
}
77+
78+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
add_mlir_dialect_library(MLIRTransformDialect
22
TransformDialect.cpp
33
TransformInterfaces.cpp
4+
TransformOps.cpp
45

56
DEPENDS
67
MLIRTransformDialectIncGen
78
MLIRTransformInterfacesIncGen
89

910
LINK_LIBS PUBLIC
1011
MLIRIR
12+
MLIRPDL
13+
MLIRPDLInterp
1114
)

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10-
11-
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
10+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1211

1312
using namespace mlir;
1413

15-
void transform::TransformDialect::initialize() {}
14+
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
15+
16+
void transform::TransformDialect::initialize() {
17+
addOperations<
18+
#define GET_OP_LIST
19+
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
20+
>();
21+
}

0 commit comments

Comments
 (0)