@@ -259,6 +259,150 @@ struct IfOpInterface
  }
};

+/// Helper function for loop bufferization. Return the indices of all values
+/// that have a tensor type.
+static DenseSet<int64_t> getTensorIndices(ValueRange values) {
+  DenseSet<int64_t> result;
+  for (const auto &it : llvm::enumerate(values))
+    if (it.value().getType().isa<TensorType>())
+      result.insert(it.index());
+  return result;
+}
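+// Illustrative example (value names are made up, not from this change): for a
+// range holding (%iv : index, %t : tensor<?xf32>), getTensorIndices returns
+// {1}, because only the value at index 1 has a tensor type.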
+
+/// Helper function for loop bufferization. Return the indices of all
+/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
+DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+                                       ValueRange yieldedValues,
+                                       const AnalysisState &state) {
+  DenseSet<int64_t> result;
+  int64_t counter = 0;
+  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
+    if (!std::get<0>(it).getType().isa<TensorType>())
+      continue;
+    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
+      result.insert(counter);
+    counter++;
+  }
+  return result;
+}
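+// Illustrative example (a sketch of typical input, not from this change):
+// yielding the bbArg itself, e.g. `scf.yield %arg`, bufferizes to an
+// equivalent buffer, so its index appears in the returned set; yielding an
+// unrelated tensor defined outside the loop typically does not.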
+
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+}
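+// Illustrative example (types chosen for exposition only): if the yielded
+// buffer has type `memref<4xf32>` but the iter_arg expects `memref<?xf32>`,
+// this helper emits `memref.cast %buffer : memref<4xf32> to memref<?xf32>`,
+// a cast-compatible conversion from a static to a dynamic shape.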
+
+/// Helper function for loop bufferization. Return the bufferized values of the
+/// given OpOperands. If an operand is not a tensor, return the original value.
+static SmallVector<Value> getBuffers(RewriterBase &rewriter,
+                                     MutableArrayRef<OpOperand> operands,
+                                     BufferizationState &state) {
+  SmallVector<Value> result;
+  for (OpOperand &opOperand : operands) {
+    if (opOperand.get().getType().isa<TensorType>()) {
+      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+      if (failed(resultBuffer))
+        return {};
+      result.push_back(*resultBuffer);
+    } else {
+      result.push_back(opOperand.get());
+    }
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Compute the buffer that should be
+/// yielded from a loop block (loop body or loop condition). If the given tensor
+/// is equivalent to the corresponding block argument (as indicated by
+/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
+/// copy must be yielded.
+///
+/// According to the `BufferizableOpInterface` implementation of scf loops, a
+/// bufferized OpResult may alias only with the corresponding bufferized
+/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
+/// the i-th init_arg, but not with any other OpOperand. If a corresponding
+/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
+/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
+/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
+/// alias with any buffer.)
+static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                              BaseMemRefType type, bool isEquivalent,
+                              BufferizationState &state) {
+  assert(tensor.getType().isa<TensorType>() && "expected tensor");
+  ensureToMemrefOpIsValid(tensor, type);
+  Value yieldedVal =
+      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
+
+  if (isEquivalent)
+    // Yielded value is equivalent to the corresponding iter_arg bbArg.
+    // Yield the value directly. Most IR should be like that. Everything
+    // else must be resolved with copies and is potentially inefficient.
+    // By default, such problematic IR would already have been rejected
+    // during `verifyAnalysis`, unless `allow-return-allocs`.
+    return castBuffer(rewriter, yieldedVal, type);
+
+  // It is not certain that the yielded value and the iter_arg bbArg
+  // have the same buffer. Allocate a new buffer and copy. The yielded
+  // buffer will get deallocated by `deallocateBuffers`.
+
+  // TODO: There are cases in which it is not necessary to return a new
+  // buffer allocation. E.g., when equivalent values are yielded in a
+  // different order. This could be resolved with copies.
+  Optional<Value> yieldedAlloc = state.createAlloc(
+      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
+  // TODO: We should rollback, but for now just assume that this always
+  // succeeds.
+  assert(yieldedAlloc.hasValue() && "could not create alloc");
+  LogicalResult copyStatus = bufferization::createMemCpy(
+      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
+  (void)copyStatus;
+  assert(succeeded(copyStatus) && "could not create memcpy");
+
+  // The iter_arg memref type may have a layout map. Cast the new buffer
+  // to the same type if needed.
+  return castBuffer(rewriter, *yieldedAlloc, type);
+}
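+// Illustrative example (IR sketch with made-up value names, not from this
+// change): in
+//   %r = scf.for %iv = %lb to %ub step %c1 iter_args(%a = %init) -> (tensor<?xf32>) {
+//     scf.yield %other : tensor<?xf32>   // %other defined outside the loop
+//   }
+// the yielded %other is not equivalent to the iter_arg %a, so a fresh buffer
+// is allocated, %other's buffer is copied into it, and that copy is yielded.
+// Yielding %a itself (or an in-place update of it) needs no copy.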
+
+/// Helper function for loop bufferization. Given a range of values, apply
+/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
+/// value in the result vector.
+static SmallVector<Value>
+convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
+                    llvm::function_ref<Value(Value, int64_t)> func) {
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(values)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+  }
+  return result;
+}
+
+/// Helper function for loop bufferization. Given a list of pre-bufferization
+/// yielded values, compute the list of bufferized yielded values.
+SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
+                                    TypeRange bufferizedTypes,
+                                    const DenseSet<int64_t> &tensorIndices,
+                                    const DenseSet<int64_t> &equivalentTensors,
+                                    BufferizationState &state) {
+  return convertTensorValues(
+      values, tensorIndices, [&](Value val, int64_t index) {
+        return getYieldedBuffer(rewriter, val,
+                                bufferizedTypes[index].cast<BaseMemRefType>(),
+                                equivalentTensors.contains(index), state);
+      });
+}
+
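+// Illustrative before/after sketch of scf.for bufferization (simplified,
+// value names are made up; the equivalent-yield case is shown):
+//
+//   %r = scf.for %iv = %lb to %ub step %c1 iter_args(%t = %init) -> (tensor<?xf32>) {
+//     ...
+//     scf.yield %t2 : tensor<?xf32>
+//   }
+//
+// becomes, roughly:
+//
+//   %init_m = bufferization.to_memref %init : memref<?xf32>
+//   %r_m = scf.for %iv = %lb to %ub step %c1 iter_args(%m = %init_m) -> (memref<?xf32>) {
+//     %t = bufferization.to_tensor %m : memref<?xf32>
+//     ...
+//     scf.yield %m2 : memref<?xf32>   // buffer of the yielded tensor (or a copy)
+//   }
+//   %r = bufferization.to_tensor %r_m : memref<?xf32>
+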
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
@@ -312,78 +456,38 @@ struct ForOpInterface
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    auto oldYieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

-    // Helper function for casting MemRef buffers.
-    auto castBuffer = [&](Value buffer, Type type) {
-      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-      assert(buffer.getType().isa<BaseMemRefType>() &&
-             "expected BaseMemRefType");
-      // If the buffer already has the correct type, no cast is needed.
-      if (buffer.getType() == type)
-        return buffer;
-      // TODO: In case `type` has a layout map that is not the fully dynamic
-      // one, we may not be able to cast the buffer. In that case, the loop
-      // iter_arg's layout map must be changed (see uses of `castBuffer`).
-      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-             "scf.for op bufferization: cast incompatible");
-      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
-          .getResult();
-    };
-
    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
-    DenseSet<int64_t> indices;
+    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
-    SmallVector<bool> equivalentYields;
-    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
-      if (it.value().getType().isa<TensorType>()) {
-        indices.insert(it.index());
-        BufferRelation relation = bufferizableOp.bufferRelation(
-            forOp->getResult(it.index()), state.getAnalysisState());
-        equivalentYields.push_back(relation == BufferRelation::Equivalent);
-      } else {
-        equivalentYields.push_back(false);
-      }
-    }
+    DenseSet<int64_t> equivalentYields =
+        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
+                             state.getAnalysisState());

-    // Given a range of values, apply `func` to those marked in `indices`.
-    // Otherwise, store the unmodified value in the result vector.
-    auto convert = [&](ValueRange values,
-                       llvm::function_ref<Value(Value, int64_t)> func) {
-      SmallVector<Value> result;
-      for (const auto &it : llvm::enumerate(values)) {
-        size_t idx = it.index();
-        Value val = it.value();
-        result.push_back(indices.contains(idx) ? func(val, idx) : val);
-      }
-      return result;
-    };
+    // The new memref init_args of the loop.
+    SmallVector<Value> initArgs =
+        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+    if (initArgs.size() != indices.size())
+      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
-    SmallVector<Value> initArgs;
-    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
-      if (opOperand.get().getType().isa<TensorType>()) {
-        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
-        if (failed(resultBuffer))
-          return failure();
-        initArgs.push_back(*resultBuffer);
-      } else {
-        initArgs.push_back(opOperand.get());
-      }
-    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
+    ValueRange initArgsRange(initArgs);
+    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs =
-        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
+    SmallVector<Value> iterArgs = convertTensorValues(
+        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
@@ -399,42 +503,8 @@ struct ForOpInterface
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
-        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          Type initArgType = initArgs[index].getType();
-          ensureToMemrefOpIsValid(val, initArgType);
-          Value yieldedVal =
-              bufferization::lookupBuffer(rewriter, val, state.getOptions());
-
-          if (equivalentYields[index])
-            // Yielded value is equivalent to the corresponding iter_arg bbArg.
-            // Yield the value directly. Most IR should be like that. Everything
-            // else must be resolved with copies and is potentially inefficient.
-            // By default, such problematic IR would already have been rejected
-            // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return castBuffer(yieldedVal, initArgType);
-
-          // It is not certain that the yielded value and the iter_arg bbArg
-          // have the same buffer. Allocate a new buffer and copy. The yielded
-          // buffer will get deallocated by `deallocateBuffers`.
-
-          // TODO: There are cases in which it is not neccessary to return a new
-          // buffer allocation. E.g., when equivalent values are yielded in a
-          // different order. This could be resolved with copies.
-          Optional<Value> yieldedAlloc = state.createAlloc(
-              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
-          // TODO: We should rollback, but for now just assume that this always
-          // succeeds.
-          assert(yieldedAlloc.hasValue() && "could not create alloc");
-          LogicalResult copyStatus =
-              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
-                                          *yieldedAlloc, state.getOptions());
-          (void)copyStatus;
-          assert(succeeded(copyStatus) && "could not create memcpy");
-
-          // The iter_arg memref type may have a layout map. Cast the new buffer
-          // to the same type if needed.
-          return castBuffer(*yieldedAlloc, initArgType);
-        });
+        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
+                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.