@@ -33,6 +33,14 @@ class TransformOpInterface;
33
33
// / expected to populate the `TransformResults` class instance in order to
34
34
// / update the mapping. The `applyTransform` method takes care of propagating
35
35
// / the state of `TransformResults` into the instance of this class.
36
+ // /
37
+ // / When applying transform IR operations with regions, the client is expected
38
+ // / to create a RegionScope RAII object to create a new "stack frame" for
39
+ // / values defined inside the region. The mappings from and to these values will
40
+ // / be automatically dropped when the object goes out of scope, typically at the
41
+ // / end of the "apply" function of the parent operation. If a region contains
42
+ // / blocks with arguments, the client can map those arguments to payload IR ops
43
+ // / using "mapBlockArguments".
36
44
class TransformState {
37
45
// / Mapping between a Value in the transform IR and the corresponding set of
38
46
// / operations in the payload IR.
@@ -42,9 +50,19 @@ class TransformState {
42
50
// / currently associated with.
43
51
using TransformOpReverseMapping = DenseMap<Operation *, Value>;
44
52
53
+ // / Bidirectional mappings between transform IR values and payload IR
54
+ // / operations.
55
+ struct Mappings {
56
+ TransformOpMapping direct;
57
+ TransformOpReverseMapping reverse;
58
+ };
59
+
45
60
public:
46
- // / Creates a state for the transformation rooted at the given op.
47
- explicit TransformState (Operation *root);
61
+ // / Creates a state for transform ops living in the given region. The parent
62
+ // / operation of the region. The second argument points to the root operation
63
+ // / in the payload IR beind transformed, which may or may not contain the
64
+ // / region with transform ops.
65
+ TransformState (Region ®ion, Operation *root);
48
66
49
67
// / Returns the op at which the transformation state is rooted. This is
50
68
// / typically helpful for transformations that apply globally.
@@ -58,10 +76,96 @@ class TransformState {
58
76
// / the state accordingly.
59
77
LogicalResult applyTransform (TransformOpInterface transform);
60
78
79
+ // / Records the mapping between a block argument in the transform IR and a
80
+ // / list of operations in the payload IR. The arguments must be defined in
81
+ // / blocks of the currently processed transform IR region, typically after a
82
+ // / region scope is defined.
83
+ LogicalResult mapBlockArguments (BlockArgument argument,
84
+ ArrayRef<Operation *> operations) {
85
+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
86
+ assert (argument.getParentRegion () == regionStack.back () &&
87
+ " mapping block arguments from a region other than the active one" );
88
+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
89
+ return setPayloadOps (argument, operations);
90
+ }
91
+
92
+ // Forward declarations to support limited visibility.
93
+ class RegionScope ;
94
+
95
+ // / Creates a new region scope for the given region. The region is expected to
96
+ // / be nested in the currently processed region.
97
+ // Implementation note: this method is inline but implemented outside of the
98
+ // class body to comply with visibility and full-declaration requirements.
99
+ inline RegionScope make_region_scope (Region ®ion);
100
+
101
+ // / A RAII object maintaining a "stack frame" for a transform IR region. When
102
+ // / applying a transform IR operation that contains a region, the caller is
103
+ // / expected to create a RegionScope before applying the ops contained in the
104
+ // / region. This ensures that the mappings between values defined in the
105
+ // / transform IR region and payload IR operations are cleared when the region
106
+ // / processing ends; such values cannot be accessed outside the region.
107
+ class RegionScope {
108
+ public:
109
+ // / Forgets the mapping from or to values defined in the associated
110
+ // / transform IR region.
111
+ ~RegionScope () {
112
+ state.mappings .erase (region);
113
+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
114
+ state.regionStack .pop_back ();
115
+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
116
+ }
117
+
118
+ private:
119
+ // / Creates a new scope for mappings between values defined in the given
120
+ // / transform IR region and payload IR operations.
121
+ RegionScope (TransformState &state, Region ®ion)
122
+ : state(state), region(®ion) {
123
+ auto res = state.mappings .try_emplace (this ->region );
124
+ assert (res.second && " the region scope is already present" );
125
+ (void )res;
126
+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
127
+ assert (state.regionStack .back ()->isProperAncestor (®ion) &&
128
+ " scope started at a non-nested region" );
129
+ state.regionStack .push_back (®ion);
130
+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
131
+ }
132
+
133
+ // / Back-reference to the transform state.
134
+ TransformState &state;
135
+
136
+ // / The region this scope is associated with.
137
+ Region *region;
138
+
139
+ friend RegionScope TransformState::make_region_scope (Region &);
140
+ };
141
+ friend class RegionScope ;
142
+
61
143
private:
62
144
// / Identifier for storing top-level value in the `operations` mapping.
63
145
static constexpr Value kTopLevelValue = Value();
64
146
147
+ // / Returns the mappings frame for the reigon in which the value is defined.
148
+ const Mappings &getMapping (Value value) const {
149
+ return const_cast <TransformState *>(this )->getMapping (value);
150
+ }
151
+ Mappings &getMapping (Value value) {
152
+ auto it = mappings.find (value.getParentRegion ());
153
+ assert (it != mappings.end () &&
154
+ " trying to find a mapping for a value from an unmapped region" );
155
+ return it->second ;
156
+ }
157
+
158
+ // / Returns the mappings frame for the region in which the operation resides.
159
+ const Mappings &getMapping (Operation *operation) const {
160
+ return const_cast <TransformState *>(this )->getMapping (operation);
161
+ }
162
+ Mappings &getMapping (Operation *operation) {
163
+ auto it = mappings.find (operation->getParentRegion ());
164
+ assert (it != mappings.end () &&
165
+ " trying to find a mapping for an operation from an unmapped region" );
166
+ return it->second ;
167
+ }
168
+
65
169
// / Sets the payload IR ops associated with the given transform IR value.
66
170
// / Fails if this would result in multiple transform IR values with uses
67
171
// / corresponding to the same payload IR ops. For example, a hypothetical
@@ -88,9 +192,19 @@ class TransformState {
88
192
void updatePayloadOps (Value value,
89
193
function_ref<Operation *(Operation *)> callback);
90
194
91
- // / The mapping between payload IR values and transform IR ops.
92
- TransformOpMapping operationMapping;
93
- TransformOpReverseMapping reverseMapping;
195
+ // / The mappings between transform IR values and payload IR ops, aggregated by
196
+ // / the region in which the transform IR values are defined.
197
+ llvm::SmallDenseMap<Region *, Mappings> mappings;
198
+
199
+ // / The top-level operation that contains all payload IR, typically a module.
200
+ Operation *topLevel;
201
+
202
+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
203
+ // / A stack of nested regions that are being processed in the transform IR.
204
+ // / Each region must be an ancestor of the following regions in this list.
205
+ // / These are also the keys for "mappings".
206
+ SmallVector<Region *> regionStack;
207
+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
94
208
};
95
209
96
210
// / Local mapping between values defined by a specific op implementing the
@@ -123,6 +237,10 @@ class TransformResults {
123
237
SmallVector<Operation *> operations;
124
238
};
125
239
240
+ TransformState::RegionScope TransformState::make_region_scope (Region ®ion) {
241
+ return RegionScope (*this , region);
242
+ }
243
+
126
244
} // namespace transform
127
245
} // namespace mlir
128
246
0 commit comments