Commit 417e1c7

[mlir][scf][bufferize][NFC] Split ForOp bufferization into smaller functions
This is in preparation for WhileOp bufferization, which reuses these functions.

Differential Revision: https://reviews.llvm.org/D124933
1 parent f178c38 commit 417e1c7
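
The helpers introduced below are written against generic value ranges rather than scf.for specifics. As a rough illustration of the intended reuse, a while-loop bufferization could compose them along these lines. This is a minimal sketch only; the function name `bufferizeWhileOpSketch`, the scf::WhileOp accessors, and the overall flow are assumptions for illustration, not code from this commit:

// Hypothetical sketch (not part of this commit): WhileOp bufferization
// reusing the helpers extracted below.
LogicalResult bufferizeWhileOpSketch(scf::WhileOp whileOp,
                                     RewriterBase &rewriter,
                                     BufferizationState &state) {
  // Positions of tensor-typed init operands; only these are bufferized.
  DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
  // Bufferized init operands; non-tensor operands pass through unchanged.
  SmallVector<Value> initArgs =
      getBuffers(rewriter, whileOp->getOpOperands(), state);
  // getBuffers returns an empty vector when bufferizing an operand failed.
  if (initArgs.size() != whileOp->getNumOperands())
    return failure();
  // From here the flow would mirror ForOpInterface::bufferize below: create
  // a memref-based scf.while, wrap tensor block arguments in to_tensor ops
  // via convertTensorValues, and compute the yielded buffers for both
  // regions with getYieldedValues.
  return success();
}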

1 file changed: +161 -91 lines


mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 161 additions & 91 deletions
@@ -259,6 +259,150 @@ struct IfOpInterface
   }
 };
 
+/// Helper function for loop bufferization. Return the indices of all values
+/// that have a tensor type.
+static DenseSet<int64_t> getTensorIndices(ValueRange values) {
+  DenseSet<int64_t> result;
+  for (const auto &it : llvm::enumerate(values))
+    if (it.value().getType().isa<TensorType>())
+      result.insert(it.index());
+  return result;
+}
+
+/// Helper function for loop bufferization. Return the indices of all
+/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
+DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+                                       ValueRange yieldedValues,
+                                       const AnalysisState &state) {
+  DenseSet<int64_t> result;
+  int64_t counter = 0;
+  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
+    if (!std::get<0>(it).getType().isa<TensorType>())
+      continue;
+    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
+      result.insert(counter);
+    counter++;
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+}
+
+/// Helper function for loop bufferization. Return the bufferized values of the
+/// given OpOperands. If an operand is not a tensor, return the original value.
+static SmallVector<Value> getBuffers(RewriterBase &rewriter,
+                                     MutableArrayRef<OpOperand> operands,
+                                     BufferizationState &state) {
+  SmallVector<Value> result;
+  for (OpOperand &opOperand : operands) {
+    if (opOperand.get().getType().isa<TensorType>()) {
+      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+      if (failed(resultBuffer))
+        return {};
+      result.push_back(*resultBuffer);
+    } else {
+      result.push_back(opOperand.get());
+    }
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Compute the buffer that should be
+/// yielded from a loop block (loop body or loop condition). If the given tensor
+/// is equivalent to the corresponding block argument (as indicated by
+/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
+/// copy must be yielded.
+///
+/// According to the `BufferizableOpInterface` implementation of scf loops, a
+/// bufferized OpResult may alias only with the corresponding bufferized
+/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
+/// the i-th init_arg; but not with any other OpOperand. If a corresponding
+/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
+/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
+/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
+/// alias with any buffer.)
+static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                              BaseMemRefType type, bool isEquivalent,
+                              BufferizationState &state) {
+  assert(tensor.getType().isa<TensorType>() && "expected tensor");
+  ensureToMemrefOpIsValid(tensor, type);
+  Value yieldedVal =
+      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
+
+  if (isEquivalent)
+    // Yielded value is equivalent to the corresponding iter_arg bbArg.
+    // Yield the value directly. Most IR should be like that. Everything
+    // else must be resolved with copies and is potentially inefficient.
+    // By default, such problematic IR would already have been rejected
+    // during `verifyAnalysis`, unless `allow-return-allocs`.
+    return castBuffer(rewriter, yieldedVal, type);
+
+  // It is not certain that the yielded value and the iter_arg bbArg
+  // have the same buffer. Allocate a new buffer and copy. The yielded
+  // buffer will get deallocated by `deallocateBuffers`.
+
+  // TODO: There are cases in which it is not necessary to return a new
+  // buffer allocation. E.g., when equivalent values are yielded in a
+  // different order. This could be resolved with copies.
+  Optional<Value> yieldedAlloc = state.createAlloc(
+      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
+  // TODO: We should rollback, but for now just assume that this always
+  // succeeds.
+  assert(yieldedAlloc.hasValue() && "could not create alloc");
+  LogicalResult copyStatus = bufferization::createMemCpy(
+      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+  (void)copyStatus;
+  assert(succeeded(copyStatus) && "could not create memcpy");
+
+  // The iter_arg memref type may have a layout map. Cast the new buffer
+  // to the same type if needed.
+  return castBuffer(rewriter, *yieldedAlloc, type);
+}
+
+/// Helper function for loop bufferization. Given a range of values, apply
+/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
+/// value in the result vector.
+static SmallVector<Value>
+convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
+                    llvm::function_ref<Value(Value, int64_t)> func) {
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(values)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Given a list of pre-bufferization
+/// yielded values, compute the list of bufferized yielded values.
+SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
+                                    TypeRange bufferizedTypes,
+                                    const DenseSet<int64_t> &tensorIndices,
+                                    const DenseSet<int64_t> &equivalentTensors,
+                                    BufferizationState &state) {
+  return convertTensorValues(
+      values, tensorIndices, [&](Value val, int64_t index) {
+        return getYieldedBuffer(rewriter, val,
+                                bufferizedTypes[index].cast<BaseMemRefType>(),
+                                equivalentTensors.contains(index), state);
+      });
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -312,78 +456,38 @@ struct ForOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    auto oldYieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
-    // Helper function for casting MemRef buffers.
-    auto castBuffer = [&](Value buffer, Type type) {
-      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-      assert(buffer.getType().isa<BaseMemRefType>() &&
-             "expected BaseMemRefType");
-      // If the buffer already has the correct type, no cast is needed.
-      if (buffer.getType() == type)
-        return buffer;
-      // TODO: In case `type` has a layout map that is not the fully dynamic
-      // one, we may not be able to cast the buffer. In that case, the loop
-      // iter_arg's layout map must be changed (see uses of `castBuffer`).
-      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-             "scf.for op bufferization: cast incompatible");
-      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
-          .getResult();
-    };
-
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
-    DenseSet<int64_t> indices;
+    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
     // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
-    SmallVector<bool> equivalentYields;
-    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
-      if (it.value().getType().isa<TensorType>()) {
-        indices.insert(it.index());
-        BufferRelation relation = bufferizableOp.bufferRelation(
-            forOp->getResult(it.index()), state.getAnalysisState());
-        equivalentYields.push_back(relation == BufferRelation::Equivalent);
-      } else {
-        equivalentYields.push_back(false);
-      }
-    }
+    DenseSet<int64_t> equivalentYields =
+        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
+                             state.getAnalysisState());
 
-    // Given a range of values, apply `func` to those marked in `indices`.
-    // Otherwise, store the unmodified value in the result vector.
-    auto convert = [&](ValueRange values,
-                       llvm::function_ref<Value(Value, int64_t)> func) {
-      SmallVector<Value> result;
-      for (const auto &it : llvm::enumerate(values)) {
-        size_t idx = it.index();
-        Value val = it.value();
-        result.push_back(indices.contains(idx) ? func(val, idx) : val);
-      }
-      return result;
-    };
+    // The new memref init_args of the loop.
+    SmallVector<Value> initArgs =
+        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+    if (initArgs.size() != indices.size())
+      return failure();
 
     // Construct a new scf.for op with memref instead of tensor values.
-    SmallVector<Value> initArgs;
-    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
-      if (opOperand.get().getType().isa<TensorType>()) {
-        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
-        if (failed(resultBuffer))
-          return failure();
-        initArgs.push_back(*resultBuffer);
-      } else {
-        initArgs.push_back(opOperand.get());
-      }
-    }
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);
+    ValueRange initArgsRange(initArgs);
+    TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs =
-        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+    SmallVector<Value> iterArgs = convertTensorValues(
+        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
           return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
         });
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
@@ -399,42 +503,8 @@ struct ForOpInterface
     auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
-        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          Type initArgType = initArgs[index].getType();
-          ensureToMemrefOpIsValid(val, initArgType);
-          Value yieldedVal =
-              bufferization::lookupBuffer(rewriter, val, state.getOptions());
-
-          if (equivalentYields[index])
-            // Yielded value is equivalent to the corresponding iter_arg bbArg.
-            // Yield the value directly. Most IR should be like that. Everything
-            // else must be resolved with copies and is potentially inefficient.
-            // By default, such problematic IR would already have been rejected
-            // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return castBuffer(yieldedVal, initArgType);
-
-          // It is not certain that the yielded value and the iter_arg bbArg
-          // have the same buffer. Allocate a new buffer and copy. The yielded
-          // buffer will get deallocated by `deallocateBuffers`.
-
-          // TODO: There are cases in which it is not neccessary to return a new
-          // buffer allocation. E.g., when equivalent values are yielded in a
-          // different order. This could be resolved with copies.
-          Optional<Value> yieldedAlloc = state.createAlloc(
-              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
-          // TODO: We should rollback, but for now just assume that this always
-          // succeeds.
-          assert(yieldedAlloc.hasValue() && "could not create alloc");
-          LogicalResult copyStatus =
-              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
-                                          *yieldedAlloc, state.getOptions());
-          (void)copyStatus;
-          assert(succeeded(copyStatus) && "could not create memcpy");
-
-          // The iter_arg memref type may have a layout map. Cast the new buffer
-          // to the same type if needed.
-          return castBuffer(*yieldedAlloc, initArgType);
-        });
+        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
+                         equivalentYields, state);
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.
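
The copy-on-yield path in getYieldedBuffer is easiest to see on a loop that yields its tensor iter_args in a different order than it received them. A worked example in comments (illustrative only; the types and value names are made up):

// %r:2 = scf.for ... iter_args(%a = %t0, %b = %t1)
//     -> (tensor<?xf32>, tensor<?xf32>) {
//   ...
//   scf.yield %b, %a : tensor<?xf32>, tensor<?xf32>  // yielded swapped
// }
//
// %b is not equivalent to the bbArg %a (and vice versa), so each yielded
// buffer gets a fresh allocation plus a memcpy. The TODO in getYieldedBuffer
// notes that this particular pattern could instead be resolved with copies
// between the existing buffers.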

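The layout-map caveat in castBuffer's TODO can be made concrete. A minimal sketch, assuming an OpBuilder `b`, a Value `buf` of type memref<10xf32> with identity layout, and the strided-layout helpers available around the time of this commit (variable names are illustrative):

// Build a memref type with a fully dynamic strided layout (unknown offset
// and stride). Casting identity -> fully dynamic is always cast compatible,
// so this castBuffer call succeeds. The reverse direction (fully dynamic ->
// a specific layout) may not be, which is the case the TODO describes.
AffineMap dynLayout = makeStridedLinearLayoutMap(
    /*strides=*/{ShapedType::kDynamicStrideOrOffset},
    /*offset=*/ShapedType::kDynamicStrideOrOffset, b.getContext());
auto dynType = MemRefType::get({10}, b.getF32Type(), dynLayout);
Value casted = castBuffer(b, buf, dynType); // no-op if types already match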