@@ -92,6 +92,130 @@ FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
92
92
return reportUnknownTransformError (target);
93
93
}
94
94
95
+ // ===----------------------------------------------------------------------===//
96
+ // FuseOp
97
+ // ===----------------------------------------------------------------------===//
98
+
99
+ // / Apply a tiling transformation to all payload ops and store both the
100
+ // / tiled operation as well as the created tile loops.
101
+ static LogicalResult
102
+ applyTilingToAll (Operation *transformOp, Value target,
103
+ ArrayRef<int64_t > tileSizes,
104
+ transform::TransformResults &transformResults,
105
+ transform::TransformState &state,
106
+ function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
107
+ // Number of loops: Number of tiles sizes that are not zero.
108
+ size_t numLoops = tileSizes.size () - llvm::count (tileSizes, 0 );
109
+ // All payload ops. These should all be LinalgOps for now.
110
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps (target);
111
+
112
+ SmallVector<Operation *> tiledLinalgOps;
113
+ SmallVector<SmallVector<Operation *>> loopOps (numLoops);
114
+ for (unsigned int i = 0 ; i < numLoops; ++i)
115
+ loopOps[i].reserve (payloadOps.size ());
116
+
117
+ for (Operation *target : payloadOps) {
118
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
119
+ if (!linalgOp)
120
+ return transformOp->emitError (" only LinalgOps are supported" );
121
+
122
+ FailureOr<TiledLinalgOp> tiled = applyFn (linalgOp);
123
+ if (failed (tiled))
124
+ return failure ();
125
+
126
+ tiledLinalgOps.push_back (tiled->op );
127
+ if (tiled->loops .size () != numLoops)
128
+ // Not enough loops were generated. This usually means that the input size
129
+ // was smaller than the tiling size.
130
+ // TODO: LinalgTilingPattern should return failure().
131
+ return failure ();
132
+ for (unsigned int i = 0 ; i < numLoops; ++i)
133
+ loopOps[i].push_back (tiled->loops [i]);
134
+ }
135
+
136
+ transformResults.set (transformOp->getOpResult (0 ), tiledLinalgOps);
137
+ for (unsigned int i = 0 ; i < numLoops; ++i)
138
+ transformResults.set (transformOp->getOpResult (i + 1 ), loopOps[i]);
139
+ return success ();
140
+ }
141
+
142
+ // / Parse a tiling-like operation that returns the tiled op as well as the
143
+ // / created tile loops. The function counts the non-zero tile sizes to compute
144
+ // / the number of results.
145
+ static ParseResult parseTileLikeOp (OpAsmParser &parser, OperationState &result,
146
+ StringRef sizesAttrName) {
147
+ OpAsmParser::UnresolvedOperand targetOperand;
148
+ SMLoc opLoc = parser.getCurrentLocation ();
149
+ if (parser.parseOperand (targetOperand) ||
150
+ parser.parseOptionalAttrDict (result.attributes ))
151
+ return failure ();
152
+ Attribute sizesAttr = result.attributes .get (sizesAttrName);
153
+ if (!sizesAttr)
154
+ return parser.emitError (opLoc)
155
+ << " expected '" << sizesAttrName << " ' attribute" ;
156
+ auto sizesArrayAttr = sizesAttr.dyn_cast <ArrayAttr>();
157
+ if (!sizesArrayAttr)
158
+ return parser.emitError (opLoc)
159
+ << " '" << sizesAttrName << " ' attribute must be an array" ;
160
+ Type pdlOpType = parser.getBuilder ().getType <pdl::OperationType>();
161
+ size_t numExpectedLoops =
162
+ sizesArrayAttr.size () - llvm::count (extractI64Array (sizesArrayAttr), 0 );
163
+ result.addTypes (SmallVector<Type>(numExpectedLoops + 1 , pdlOpType));
164
+ if (parser.resolveOperand (targetOperand, pdlOpType, result.operands ))
165
+ return failure ();
166
+ return success ();
167
+ }
168
+
169
+ LogicalResult
170
+ transform::FuseOp::apply (mlir::transform::TransformResults &transformResults,
171
+ mlir::transform::TransformState &state) {
172
+ LinalgTilingAndFusionOptions fusionOptions;
173
+ fusionOptions.tileSizes = extractI64Array (getTileSizes ());
174
+ fusionOptions.tileInterchange = extractI64Array (getTileInterchange ());
175
+
176
+ return applyTilingToAll (
177
+ getOperation (), getTarget (), fusionOptions.tileSizes , transformResults,
178
+ state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
179
+ LinalgTileAndFuseTensorOpsPattern pattern (getContext (), fusionOptions);
180
+ SimpleRewriter rewriter (getContext ());
181
+ rewriter.setInsertionPoint (linalgOp);
182
+ FailureOr<TileLoopNest> tileLoopNest =
183
+ pattern.returningMatchAndRewrite (linalgOp, rewriter);
184
+ if (failed (tileLoopNest))
185
+ return failure ();
186
+
187
+ TiledLinalgOp tiledLinalgOp;
188
+ tiledLinalgOp.op = tileLoopNest->getRootOp ();
189
+ tiledLinalgOp.loops = {tileLoopNest->getLoopOps ().begin (),
190
+ tileLoopNest->getLoopOps ().end ()};
191
+ return tiledLinalgOp;
192
+ });
193
+ }
194
+
195
+ ParseResult transform::FuseOp::parse (OpAsmParser &parser,
196
+ OperationState &result) {
197
+ return parseTileLikeOp (
198
+ parser, result,
199
+ transform::FuseOp::getTileSizesAttrName (result.name ).getValue ());
200
+ }
201
+
202
+ void transform::FuseOp::print (OpAsmPrinter &p) {
203
+ p << ' ' ;
204
+ p << getTarget ();
205
+ p.printOptionalAttrDict ((*this )->getAttrs ());
206
+ }
207
+
208
+ LogicalResult transform::FuseOp::verify () {
209
+ SmallVector<int64_t > permutation = extractI64Array (getTileInterchange ());
210
+ auto sequence = llvm::to_vector (llvm::seq<int64_t >(0 , permutation.size ()));
211
+ if (!std::is_permutation (sequence.begin (), sequence.end (),
212
+ permutation.begin (), permutation.end ())) {
213
+ return emitOpError () << " expects interchange to be a permutation, found "
214
+ << getTileInterchange ();
215
+ }
216
+ return success ();
217
+ }
218
+
95
219
// ===----------------------------------------------------------------------===//
96
220
// GeneralizeOp
97
221
// ===----------------------------------------------------------------------===//
@@ -274,49 +398,6 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
274
398
// TileOp
275
399
// ===----------------------------------------------------------------------===//
276
400
277
- // / Apply a tiling transformation to all payload ops and store both the
278
- // / tiled operation as well as the created tile loops.
279
- static LogicalResult
280
- applyTilingToAll (Operation *transformOp, Value target,
281
- ArrayRef<int64_t > tileSizes,
282
- transform::TransformResults &transformResults,
283
- transform::TransformState &state,
284
- function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
285
- // Number of loops: Number of tiles sizes that are not zero.
286
- size_t numLoops = tileSizes.size () - llvm::count (tileSizes, 0 );
287
- // All payload ops. These should all be LinalgOps for now.
288
- ArrayRef<Operation *> payloadOps = state.getPayloadOps (target);
289
-
290
- SmallVector<Operation *> tiledLinalgOps;
291
- SmallVector<SmallVector<Operation *>> loopOps (numLoops);
292
- for (unsigned int i = 0 ; i < numLoops; ++i)
293
- loopOps[i].reserve (payloadOps.size ());
294
-
295
- for (Operation *target : payloadOps) {
296
- auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
297
- if (!linalgOp)
298
- return transformOp->emitError (" only LinalgOps are supported" );
299
-
300
- FailureOr<TiledLinalgOp> tiled = applyFn (linalgOp);
301
- if (failed (tiled))
302
- return failure ();
303
-
304
- tiledLinalgOps.push_back (tiled->op );
305
- if (tiled->loops .size () != numLoops)
306
- // Not enough loops were generated. This usually means that the input size
307
- // was smaller than the tiling size.
308
- // TODO: LinalgTilingPattern should return failure().
309
- return failure ();
310
- for (unsigned int i = 0 ; i < numLoops; ++i)
311
- loopOps[i].push_back (tiled->loops [i]);
312
- }
313
-
314
- transformResults.set (transformOp->getOpResult (0 ), tiledLinalgOps);
315
- for (unsigned int i = 0 ; i < numLoops; ++i)
316
- transformResults.set (transformOp->getOpResult (i + 1 ), loopOps[i]);
317
- return success ();
318
- }
319
-
320
401
LogicalResult transform::TileOp::apply (TransformResults &transformResults,
321
402
TransformState &state) {
322
403
LinalgTilingOptions tilingOptions;
@@ -337,27 +418,8 @@ LogicalResult transform::TileOp::apply(TransformResults &transformResults,
337
418
338
419
ParseResult transform::TileOp::parse (OpAsmParser &parser,
339
420
OperationState &result) {
340
- StringRef sizesAttrName = TileOp::getSizesAttrName (result.name ).getValue ();
341
- OpAsmParser::UnresolvedOperand targetOperand;
342
- SMLoc opLoc = parser.getCurrentLocation ();
343
- if (parser.parseOperand (targetOperand) ||
344
- parser.parseOptionalAttrDict (result.attributes ))
345
- return failure ();
346
- Attribute sizesAttr = result.attributes .get (sizesAttrName);
347
- if (!sizesAttr)
348
- return parser.emitError (opLoc)
349
- << " expected '" << sizesAttrName << " ' attribute" ;
350
- auto sizesArrayAttr = sizesAttr.dyn_cast <ArrayAttr>();
351
- if (!sizesArrayAttr)
352
- return parser.emitError (opLoc)
353
- << " '" << sizesAttrName << " ' attribute must be an array" ;
354
- Type pdlOpType = parser.getBuilder ().getType <pdl::OperationType>();
355
- size_t numExpectedLoops =
356
- sizesArrayAttr.size () - llvm::count (extractI64Array (sizesArrayAttr), 0 );
357
- result.addTypes (SmallVector<Type>(numExpectedLoops + 1 , pdlOpType));
358
- if (parser.resolveOperand (targetOperand, pdlOpType, result.operands ))
359
- return failure ();
360
- return success ();
421
+ return parseTileLikeOp (parser, result,
422
+ TileOp::getSizesAttrName (result.name ).getValue ());
361
423
}
362
424
363
425
void TileOp::print (OpAsmPrinter &p) {
@@ -366,26 +428,6 @@ void TileOp::print(OpAsmPrinter &p) {
366
428
p.printOptionalAttrDict ((*this )->getAttrs ());
367
429
}
368
430
369
- void TileOp::getEffects (
370
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
371
- &effects) {
372
- // `target` arg is consumed and can no longer be used.
373
- effects.emplace_back (MemoryEffects::Read::get (), getTarget (),
374
- TransformMappingResource::get ());
375
- effects.emplace_back (MemoryEffects::Free::get (), getTarget (),
376
- TransformMappingResource::get ());
377
-
378
- for (Value r : getResults ()) {
379
- effects.emplace_back (MemoryEffects::Write::get (), r,
380
- TransformMappingResource::get ());
381
- effects.emplace_back (MemoryEffects::Allocate::get (), r,
382
- TransformMappingResource::get ());
383
- }
384
-
385
- effects.emplace_back (MemoryEffects::Read::get (), PayloadIRResource::get ());
386
- effects.emplace_back (MemoryEffects::Write::get (), PayloadIRResource::get ());
387
- }
388
-
389
431
// ===----------------------------------------------------------------------===//
390
432
// VectorizeOp
391
433
// ===----------------------------------------------------------------------===//
0 commit comments