@@ -444,9 +444,12 @@ struct ForOpInterface
444
444
}
445
445
446
446
// / Assert that yielded values of an scf.for op are equivalent to their
447
- // / corresponding bbArgs. Otherwise, an alloc+copy are inserted and yielded
448
- // / from the loop. This could be a performance problem, so it must be
449
- // / explicitly activated with `alloc-return-allocs`.
447
+ // / corresponding bbArgs. In that case, the buffer relations of the
448
+ // / corresponding OpResults are "Equivalent".
449
+ // /
450
+ // / If this is not the case, an allocs+copies are inserted and yielded from
451
+ // / the loop. This could be a performance problem, so it must be explicitly
452
+ // / activated with `alloc-return-allocs`.
450
453
LogicalResult verifyAnalysis (Operation *op,
451
454
const AnalysisState &state) const {
452
455
const auto &options =
@@ -457,22 +460,19 @@ struct ForOpInterface
457
460
auto forOp = cast<scf::ForOp>(op);
458
461
auto yieldOp =
459
462
cast<scf::YieldOp>(forOp.getLoopBody ().front ().getTerminator ());
460
- for (OpOperand &operand : yieldOp->getOpOperands ()) {
461
- auto tensorType = operand.get ().getType ().dyn_cast <TensorType>();
462
- if (!tensorType)
463
+ for (OpResult opResult : op->getOpResults ()) {
464
+ if (!opResult.getType ().isa <TensorType>())
463
465
continue ;
464
466
465
- OpOperand &forOperand = forOp.getOpOperandForResult (
466
- forOp->getResult (operand.getOperandNumber ()));
467
- auto bbArg = forOp.getRegionIterArgForOpOperand (forOperand);
468
467
// Note: This is overly strict. We should check for aliasing bufferized
469
468
// values. But we don't have a "must-alias" analysis yet.
470
- if (!state. areEquivalentBufferizedValues (operand. get (), bbArg) )
469
+ if (bufferRelation (op, opResult, state) != BufferRelation::Equivalent )
471
470
return yieldOp->emitError ()
472
- << " Yield operand #" << operand. getOperandNumber ()
471
+ << " Yield operand #" << opResult. getResultNumber ()
473
472
<< " does not bufferize to a buffer that is aliasing the "
474
473
" matching enclosing scf::for operand" ;
475
474
}
475
+
476
476
return success ();
477
477
}
478
478
};
0 commit comments