
Commit d820acd

[mlir][bufferize][NFC] Use custom walk instead of GreedyPatternRewriter
The bufferization driver was previously using a GreedyPatternRewriter. This was problematic because bufferization must traverse ops top-to-bottom. The GreedyPatternRewriter was configured via `useTopDownTraversal`, but this was a hack: that API was only meant for performance improvements and should not affect the result of the rewrite.

BEGIN_PUBLIC
No public commit message needed.
END_PUBLIC

Differential Revision: https://reviews.llvm.org/D123618
1 parent (9b32886), commit d820acd
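
The diff in Bufferize.cpp below replaces the greedy driver with an explicit pre-order walk plus a worklist. As a rough sketch of that general shape (hedged: `walkAndRewrite`, `shouldRewrite`, and `rewriteOp` are illustrative names, not the commit's code), ops are collected up front and rewritten only afterwards, so the traversal order is fixed before any IR is mutated:

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Collect the ops of interest in pre-order (parents before nested ops, top to
// bottom), then rewrite them one by one with a plain IRRewriter.
// `shouldRewrite` and `rewriteOp` stand in for whatever predicate and rewrite
// logic the driver needs.
static LogicalResult walkAndRewrite(
    Operation *root, function_ref<bool(Operation *)> shouldRewrite,
    function_ref<LogicalResult(RewriterBase &, Operation *)> rewriteOp) {
  // The walk itself does not mutate IR; mutation happens afterwards, so the
  // visitation order cannot be perturbed by the rewrites.
  SmallVector<Operation *> worklist;
  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (shouldRewrite(op))
      worklist.push_back(op);
  });

  IRRewriter rewriter(root->getContext());
  for (Operation *op : worklist) {
    rewriter.setInsertionPoint(op);
    if (failed(rewriteOp(rewriter, op)))
      return failure();
  }
  return success();
}
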

5 files changed (+184 -148 lines)

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 7 additions & 0 deletions
@@ -46,6 +46,13 @@ namespace bufferization {
 /// with differing element types or memory spaces.
 FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
                                           MemRefType type);
+
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                       ToMemrefOp toMemref,
+                                       bool allowSameType = true);
+
 } // namespace bufferization
 } // namespace mlir
 

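This declaration exposes the fold so that callers outside BufferizationOps.cpp (such as the bufferization driver below) can invoke it directly. A hedged usage sketch, assuming a hypothetical helper named `foldAllToMemrefToTensorPairs` that is not part of the commit:

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Attempt to fold every to_memref(to_tensor(x)) pair in a module. Ops are
// collected before any mutation so the walk never visits erased IR.
static void foldAllToMemrefToTensorPairs(ModuleOp module) {
  SmallVector<bufferization::ToMemrefOp> toMemrefOps;
  module.walk(
      [&](bufferization::ToMemrefOp op) { toMemrefOps.push_back(op); });

  IRRewriter rewriter(module.getContext());
  for (bufferization::ToMemrefOp toMemref : toMemrefOps) {
    rewriter.setInsertionPoint(toMemref);
    // Fails harmlessly (leaving the op alone) when the operand is not
    // produced by a bufferization.to_tensor op.
    (void)bufferization::foldToMemrefToTensorPair(rewriter, toMemref);
  }
}
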
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 49 additions & 49 deletions
@@ -21,10 +21,6 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                               MemRefType destType) {
   auto srcType = value.getType().cast<MemRefType>();
 
-  // Casting to the same type, nothing to do.
-  if (srcType == destType)
-    return value;
-
   // Element type, rank and memory space must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
@@ -79,6 +75,55 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
   return copy;
 }
 
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
+  if (!memrefToTensor)
+    return failure();
+
+  Type srcType = memrefToTensor.memref().getType();
+  Type destType = toMemref.getType();
+
+  // Directly rewrite if the type did not change.
+  if (srcType == destType) {
+    // Function can be configured to only handle cases where a cast is needed.
+    if (!allowSameType)
+      return failure();
+    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+    return success();
+  }
+
+  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+  auto rankedDestType = destType.dyn_cast<MemRefType>();
+  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+  // Ranked memref -> Ranked memref cast.
+  if (rankedSrcType && rankedDestType) {
+    FailureOr<Value> replacement = castOrReallocMemRefValue(
+        rewriter, memrefToTensor.memref(), rankedDestType);
+    if (failed(replacement))
+      return failure();
+
+    rewriter.replaceOp(toMemref, *replacement);
+    return success();
+  }
+
+  // Unranked memref -> Ranked memref cast: May require a copy.
+  // TODO: Not implemented at the moment.
+  if (unrankedSrcType && rankedDestType)
+    return failure();
+
+  // Unranked memref -> unranked memref cast
+  // Ranked memref -> unranked memref cast: No copy needed.
+  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+         "expected that types are cast compatible");
+  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+                                              memrefToTensor.memref());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -249,51 +294,6 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
-/// to_memref op are different, a memref.cast is needed.
-static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                              ToMemrefOp toMemref,
-                                              bool allowSameType = true) {
-  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
-  if (!memrefToTensor)
-    return failure();
-
-  Type srcType = memrefToTensor.memref().getType();
-  Type destType = toMemref.getType();
-
-  // Function can be configured to only handle cases where a cast is needed.
-  if (!allowSameType && srcType == destType)
-    return failure();
-
-  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
-  auto rankedDestType = destType.dyn_cast<MemRefType>();
-  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
-
-  // Ranked memref -> Ranked memref cast.
-  if (rankedSrcType && rankedDestType) {
-    FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, memrefToTensor.memref(), rankedDestType);
-    if (failed(replacement))
-      return failure();
-
-    rewriter.replaceOp(toMemref, *replacement);
-    return success();
-  }
-
-  // Unranked memref -> Ranked memref cast: May require a copy.
-  // TODO: Not implemented at the moment.
-  if (unrankedSrcType && rankedDestType)
-    return failure();
-
-  // Unranked memref -> unranked memref cast
-  // Ranked memref -> unranked memref cast: No copy needed.
-  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
-         "expected that types are cast compatible");
-  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
-                                              memrefToTensor.memref());
-  return success();
-}
-
 /// Canonicalize bufferization.to_tensor + bufferization.to_memref to
 /// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
 struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {

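One behavioral detail in the hunks above: the early return for `srcType == destType` moved out of `castOrReallocMemRefValue` and into `foldToMemrefToTensorPair`. A minimal hedged sketch of what a direct caller of `castOrReallocMemRefValue` might now do if it can encounter matching types (the wrapper `castOrReuseMemRef` is illustrative, not part of the commit):

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// With the same-type shortcut removed from castOrReallocMemRefValue, a caller
// that may see identical types checks for that case itself, as
// foldToMemrefToTensorPair now does.
static FailureOr<Value> castOrReuseMemRef(RewriterBase &rewriter, Value memref,
                                          MemRefType destType) {
  if (memref.getType() == destType)
    return memref; // Nothing to cast or reallocate.
  return bufferization::castOrReallocMemRefValue(rewriter, memref, destType);
}
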
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 120 additions & 83 deletions
@@ -242,65 +242,6 @@ static bool hasTensorSemantics(Operation *op) {
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
-    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
-  BufferizationPattern(MLIRContext *context, BufferizationState &state,
-                       PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
-        state(&state) {}
-
-  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
-                                PatternRewriter &rewriter) const override {
-    const BufferizationOptions &options = state->getOptions();
-
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(bufferizableOp.getOperation()))
-      return failure();
-    if (!options.isOpAllowed(bufferizableOp.getOperation()))
-      return failure();
-    return bufferizableOp.bufferize(rewriter, *state);
-  }
-
-private:
-  BufferizationState *const state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
-  if (!options.allowUnknownOps) {
-    // Check if all ops were bufferized.
-    LogicalResult status = success();
-    op->walk([&](Operation *op) {
-      if (!hasTensorSemantics(op))
-        return WalkResult::advance();
-
-      // Bufferization dialect ops will canonicalize away if all other ops are
-      // bufferized.
-      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
-        return WalkResult::advance();
-
-      // Ops that are not in the allow list can be ignored.
-      if (!options.isOpAllowed(op))
-        return WalkResult::advance();
-
-      // Ops without any uses and no side effects will fold away.
-      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
-        return WalkResult::advance();
-
-      status = op->emitError("op was not bufferized");
-      return WalkResult::interrupt();
-    });
-
-    if (failed(status))
-      return status;
-  }
-
-  return success();
-}
-
 LogicalResult
 bufferization::finalizeBuffers(Operation *op,
                                const BufferizationOptions &options) {
@@ -335,35 +276,131 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   return success();
 }
 
+namespace {
+/// A rewriter that keeps track of extra information during bufferization.
+class BufferizationRewriter : public IRRewriter {
+public:
+  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
+                        DenseSet<Operation *> &toMemrefOps,
+                        SmallVector<Operation *> &worklist)
+      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
+        worklist(worklist) {}
+
+protected:
+  void notifyOperationRemoved(Operation *op) override {
+    IRRewriter::notifyOperationRemoved(op);
+    erasedOps.insert(op);
+  }
+
+  void notifyOperationInserted(Operation *op) override {
+    IRRewriter::notifyOperationInserted(op);
+
+    // Keep track of to_memref ops.
+    if (isa<ToMemrefOp>(op)) {
+      toMemrefOps.insert(op);
+      return;
+    }
+
+    // Skip to_tensor ops.
+    if (isa<ToTensorOp>(op))
+      return;
+
+    // A new bufferizable op was inserted. Add it to the worklist.
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  }
+
+private:
+  /// A set of all erased ops.
+  DenseSet<Operation *> &erasedOps;
+
+  /// A set of all to_memref ops.
+  DenseSet<Operation *> &toMemrefOps;
+
+  /// The list of bufferizable ops.
+  SmallVector<Operation *> &worklist;
+};
+} // namespace
+
 LogicalResult
 bufferization::bufferizeOp(Operation *op,
                            BufferizationState &bufferizationState) {
-  // Bufferize the op and its nested ops.
-  RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
-
-  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
-  // know the exact memref type of all operands. Otherwise, we have to use a
-  // memref type with a fully dynamic layout map, which has to canonicalize
-  // away. This is less efficient.
+  const auto &options = bufferizationState.getOptions();
+
+  // Keep track of to_memref ops.
+  DenseSet<Operation *> toMemrefOps;
+  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
+
+  // Gather all bufferizable ops in top-to-bottom order.
   //
-  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
-  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
-  // compatible layout maps when doing a traversal other than top-to-bottom.
-  // There are currently no canonicalization patterns to fold these away.
-  GreedyRewriteConfig config;
-  config.useTopDownTraversal = true;
-
-  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
-  // would be more efficient because every bufferization pattern is guaranteed
-  // to apply only a single time (otherwise, an assertion would be triggered).
-  // However, there are restrictions wrt. erasing ops during a preorder walk,
-  // which would likely require a larger refactoring.
-  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-    return failure();
+  // We should ideally know the exact memref type of all operands when
+  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
+  // Otherwise, we have to use a memref type with a fully dynamic layout map,
+  // which has to canonicalize away. This is less efficient.
+  //
+  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
+  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
+  // layout maps when doing a traversal other than top-to-bottom. These would
+  // not easily fold away.
+  SmallVector<Operation *> worklist;
+  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  });
 
-  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
-    return failure();
+  // Keep track of all erased ops.
+  DenseSet<Operation *> erasedOps;
+
+  // Bufferize all ops.
+  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
+                                 worklist);
+  for (unsigned i = 0; i < worklist.size(); ++i) {
+    Operation *op = worklist[i];
+    // Skip ops that were erased.
+    if (erasedOps.contains(op))
+      continue;
+    // Skip ops that are not bufferizable.
+    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+    if (!bufferizableOp)
+      continue;
+    // Continue ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Bufferize the op.
+    rewriter.setInsertionPoint(op);
+    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
+  }
+
+  // Fold all to_memref(to_tensor(x)) pairs.
+  for (Operation *op : toMemrefOps) {
+    if (erasedOps.contains(op))
+      continue;
+    rewriter.setInsertionPoint(op);
+    (void)bufferization::foldToMemrefToTensorPair(rewriter,
+                                                  cast<ToMemrefOp>(op));
+  }
+
+  /// Check the result of bufferization. Return an error if an op was not
+  /// bufferized, unless partial bufferization is allowed.
+  if (bufferizationState.getOptions().allowUnknownOps)
+    return success();
+
+  for (Operation *op : worklist) {
+    // Skip ops that are entirely gone.
+    if (erasedOps.contains(op))
+      continue;
+    // Ops that no longer have tensor semantics (because they were updated
+    // in-place) are allowed.
+    if (!hasTensorSemantics(op))
+      continue;
+    // Continue ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Ops without any uses and no side effects will fold away.
+    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
+      continue;
+    return op->emitError("op was not bufferized");
+  }
 
   return success();
 }

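A design choice visible in the new `bufferizeOp` loop above: the worklist is iterated by index and `worklist.size()` is re-read on every iteration, because `BufferizationRewriter::notifyOperationInserted` may append newly created bufferizable ops while the loop runs. A self-contained toy model of that pattern (plain C++, not the commit's code):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> worklist = {1, 2, 3};
  // Index-based iteration stays valid even though push_back may reallocate;
  // an iterator-based loop would be invalidated by the growth below.
  for (std::size_t i = 0; i < worklist.size(); ++i) {
    int item = worklist[i];
    if (item == 2)
      worklist.push_back(4); // Models the rewriter appending newly created ops.
    std::printf("processing %d\n", item);
  }
  // Prints: processing 1, 2, 3, and then the late-added 4.
  return 0;
}
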
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Lines changed: 2 additions & 2 deletions
@@ -884,8 +884,8 @@ func.func @scf_for_yield_non_equivalent(
 // CHECK: %[[cloned:.*]] = bufferization.clone %[[t]]
 // CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[cloned]])
 // This alloc is for the linalg.init_tensor.
-// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
-// CHECK: memref.dealloc %[[iter]]
+// CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}})
+// CHECK-DAG: memref.dealloc %[[iter]]
 // This alloc is for the scf.yield.
 // CHECK: %[[alloc3:.*]] = memref.alloc(%{{.*}})
 // CHECK: memref.copy %[[alloc2]], %[[alloc3]]
