@@ -242,65 +242,6 @@ static bool hasTensorSemantics(Operation *op) {
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
-    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
-  BufferizationPattern(MLIRContext *context, BufferizationState &state,
-                       PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
-        state(&state) {}
-
-  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
-                                PatternRewriter &rewriter) const override {
-    const BufferizationOptions &options = state->getOptions();
-
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(bufferizableOp.getOperation()))
-      return failure();
-    if (!options.isOpAllowed(bufferizableOp.getOperation()))
-      return failure();
-    return bufferizableOp.bufferize(rewriter, *state);
-  }
-
-private:
-  BufferizationState *const state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
-  if (!options.allowUnknownOps) {
-    // Check if all ops were bufferized.
-    LogicalResult status = success();
-    op->walk([&](Operation *op) {
-      if (!hasTensorSemantics(op))
-        return WalkResult::advance();
-
-      // Bufferization dialect ops will canonicalize away if all other ops are
-      // bufferized.
-      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
-        return WalkResult::advance();
-
-      // Ops that are not in the allow list can be ignored.
-      if (!options.isOpAllowed(op))
-        return WalkResult::advance();
-
-      // Ops without any uses and no side effects will fold away.
-      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
-        return WalkResult::advance();
-
-      status = op->emitError("op was not bufferized");
-      return WalkResult::interrupt();
-    });
-
-    if (failed(status))
-      return status;
-  }
-
-  return success();
-}
-
 LogicalResult
 bufferization::finalizeBuffers(Operation *op,
                                const BufferizationOptions &options) {
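The hunk above deletes the greedy-pattern-based driver and its after-the-fact check; the hunk below replaces both with a single worklist walk. The stand-alone sketch that follows illustrates that technique. `Op` and `Rewriter` are stand-ins invented for this illustration (the real code uses `mlir::Operation` and the `BufferizationRewriter` defined in the next hunk); only the control flow mirrors the commit: ops are collected up front, the loop iterates by index because a rewrite may append new ops, and an erased-op set guards against revisiting dead ops.

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for mlir::Operation.
struct Op {
  std::string name;
  bool hasTensorSemantics = true;
};

// Plays the role of BufferizationRewriter: notification hooks keep the
// erased-op set and the worklist in sync with IR changes.
struct Rewriter {
  std::unordered_set<Op *> &erasedOps;
  std::vector<Op *> &worklist;

  void notifyRemoved(Op *op) { erasedOps.insert(op); }
  void notifyInserted(Op *op) {
    // A new bufferizable op was created mid-walk: queue it for processing.
    if (op->hasTensorSemantics)
      worklist.push_back(op);
  }
};

int main() {
  Op a{"A"}, b{"B"}, c{"C"};
  std::vector<Op *> worklist = {&a, &b}; // result of the pre-order walk
  std::unordered_set<Op *> erasedOps;
  Rewriter rewriter{erasedOps, worklist};

  // Index-based loop: rewrites may push_back onto the worklist, which would
  // invalidate iterators but leaves indices valid.
  for (size_t i = 0; i < worklist.size(); ++i) {
    Op *op = worklist[i];
    if (erasedOps.count(op))
      continue; // erased by an earlier rewrite; do not touch
    std::cout << "bufferizing " << op->name << "\n";
    if (op == &a) {
      rewriter.notifyInserted(&c); // bufferizing A created C...
      rewriter.notifyRemoved(&b);  // ...and erased B
    }
  }
  // Output: "bufferizing A", then "bufferizing C"; B is skipped as erased.
}
```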
@@ -335,35 +276,131 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   return success();
 }
 
+namespace {
+/// A rewriter that keeps track of extra information during bufferization.
+class BufferizationRewriter : public IRRewriter {
+public:
+  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
+                        DenseSet<Operation *> &toMemrefOps,
+                        SmallVector<Operation *> &worklist)
+      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
+        worklist(worklist) {}
+
+protected:
+  void notifyOperationRemoved(Operation *op) override {
+    IRRewriter::notifyOperationRemoved(op);
+    erasedOps.insert(op);
+  }
+
+  void notifyOperationInserted(Operation *op) override {
+    IRRewriter::notifyOperationInserted(op);
+
+    // Keep track of to_memref ops.
+    if (isa<ToMemrefOp>(op)) {
+      toMemrefOps.insert(op);
+      return;
+    }
+
+    // Skip to_tensor ops.
+    if (isa<ToTensorOp>(op))
+      return;
+
+    // A new bufferizable op was inserted. Add it to the worklist.
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  }
+
+private:
+  /// A set of all erased ops.
+  DenseSet<Operation *> &erasedOps;
+
+  /// A set of all to_memref ops.
+  DenseSet<Operation *> &toMemrefOps;
+
+  /// The worklist of bufferizable ops.
+  SmallVector<Operation *> &worklist;
+};
+} // namespace
+
 LogicalResult
 bufferization::bufferizeOp(Operation *op,
                            BufferizationState &bufferizationState) {
-  // Bufferize the op and its nested ops.
-  RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
-
-  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
-  // know the exact memref type of all operands. Otherwise, we have to use a
-  // memref type with a fully dynamic layout map, which has to canonicalize
-  // away. This is less efficient.
+  const auto &options = bufferizationState.getOptions();
+
+  // Keep track of to_memref ops.
+  DenseSet<Operation *> toMemrefOps;
+  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
+
+  // Gather all bufferizable ops in top-to-bottom order.
   //
-  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
-  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
-  // compatible layout maps when doing a traversal other than top-to-bottom.
-  // There are currently no canonicalization patterns to fold these away.
-  GreedyRewriteConfig config;
-  config.useTopDownTraversal = true;
-
-  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
-  // would be more efficient because every bufferization pattern is guaranteed
-  // to apply only a single time (otherwise, an assertion would be triggered).
-  // However, there are restrictions wrt. erasing ops during a preorder walk,
-  // which would likely require a larger refactoring.
-  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-    return failure();
+  // We should ideally know the exact memref type of all operands when
+  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
+  // Otherwise, we have to use a memref type with a fully dynamic layout map,
+  // which has to canonicalize away. This is less efficient.
+  //
+  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
+  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
+  // layout maps when doing a traversal other than top-to-bottom. These would
+  // not easily fold away.
+  SmallVector<Operation *> worklist;
+  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  });
 
-  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
-    return failure();
+  // Keep track of all erased ops.
+  DenseSet<Operation *> erasedOps;
+
+  // Bufferize all ops.
+  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
+                                 worklist);
+  for (unsigned i = 0; i < worklist.size(); ++i) {
+    Operation *op = worklist[i];
+    // Skip ops that were erased.
+    if (erasedOps.contains(op))
+      continue;
+    // Skip ops that are not bufferizable.
+    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+    if (!bufferizableOp)
+      continue;
+    // Skip ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Bufferize the op.
+    rewriter.setInsertionPoint(op);
+    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
+  }
+
+  // Fold all to_memref(to_tensor(x)) pairs.
+  for (Operation *op : toMemrefOps) {
+    if (erasedOps.contains(op))
+      continue;
+    rewriter.setInsertionPoint(op);
+    (void)bufferization::foldToMemrefToTensorPair(rewriter,
+                                                  cast<ToMemrefOp>(op));
+  }
+
+  // Check the result of bufferization. Return an error if an op was not
+  // bufferized, unless partial bufferization is allowed.
+  if (bufferizationState.getOptions().allowUnknownOps)
+    return success();
+
+  for (Operation *op : worklist) {
+    // Skip ops that are entirely gone.
+    if (erasedOps.contains(op))
+      continue;
+    // Ops that no longer have tensor semantics (because they were updated
+    // in-place) are allowed.
+    if (!hasTensorSemantics(op))
+      continue;
+    // Skip ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Ops without any uses and no side effects will fold away.
+    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
+      continue;
+    return op->emitError("op was not bufferized");
+  }
 
   return success();
 }
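The to_memref(to_tensor(x)) cleanup that bufferizeOp now performs explicitly was previously left to canonicalization. Below is a minimal sketch of the trivial fold, assuming cast-compatible types; the real bufferization::foldToMemrefToTensorPair also handles layout-map mismatches by inserting a cast or a buffer copy. `Value` and `foldPair` are invented for this illustration, not MLIR APIs.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Toy value node; "kind" names the producing op. Invented for this
// illustration; not the real MLIR classes.
struct Value {
  std::string kind; // "source", "to_tensor", or "to_memref"
  std::shared_ptr<Value> operand;
};

// Fold to_memref(to_tensor(x)) -> x, assuming the buffer types match.
// The real fold otherwise inserts a memref.cast or a buffer copy.
std::shared_ptr<Value> foldPair(const std::shared_ptr<Value> &v) {
  if (v->kind == "to_memref" && v->operand && v->operand->kind == "to_tensor")
    return v->operand->operand; // the two conversions cancel out
  return v;
}

int main() {
  auto x = std::make_shared<Value>(Value{"source", nullptr});
  auto t = std::make_shared<Value>(Value{"to_tensor", x});
  auto m = std::make_shared<Value>(Value{"to_memref", t});
  std::cout << (foldPair(m) == x) << "\n"; // prints 1: folded back to x
}
```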