Skip to content

Commit 5a531b1

Browse files
authored
[mlir] NFC: Add data flow analysis extension points (#142549)
This commit introduces `visitCallOperation` and `visitCallableOperation` extension points in the sparse data flow analysis framework. This allows, for example, to make the analysis less conservative, without a lot of code duplication, propagating information even if not all the call or return sites are known.
1 parent 9411b00 commit 5a531b1

File tree

2 files changed

+114
-66
lines changed

2 files changed

+114
-66
lines changed

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
235235
/// Join the lattice element and propagate and update if it changed.
236236
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
237237

238+
/// Visits a call operation. Given the operand lattices, sets the result
239+
/// lattices. Performs interprocedural data flow as follows: if the call
240+
/// operation targets an external function, or if the solver is not
241+
/// interprocedural, attempts to infer the results from the call arguments
242+
/// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
243+
/// result lattices from the return sites if all return sites are known;
244+
/// otherwise, conservatively marks the result lattices as having reached
245+
/// their pessimistic fixpoints.
246+
/// This method can be overridden to, for example, be less conservative and
247+
/// propagate the information even if some return sites are unknown.
248+
virtual LogicalResult
249+
visitCallOperation(CallOpInterface call,
250+
ArrayRef<const AbstractSparseLattice *> operandLattices,
251+
ArrayRef<AbstractSparseLattice *> resultLattices);
252+
253+
/// Visits a callable operation. Computes the argument lattices from call
254+
/// sites if all call sites are known; otherwise, conservatively marks them
255+
/// as having reached their pessimistic fixpoints.
256+
/// This method can be overridden to, for example, be less conservative and
257+
/// propagate the information even if some call sites are unknown.
258+
virtual void
259+
visitCallableOperation(CallableOpInterface callable,
260+
ArrayRef<AbstractSparseLattice *> argLattices);
261+
238262
private:
239263
/// Recursively initialize the analysis on nested operations and blocks.
240264
LogicalResult initializeRecursively(Operation *op);
@@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
430454
/// Join the lattice element and propagate and update if it changed.
431455
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
432456

457+
/// Visits a callable operation. If all the call sites are known computes the
458+
/// operand lattices of `op` from the result lattices of all the call sites;
459+
/// otherwise, conservatively marks them as having reached their pessimistic
460+
/// fixpoints.
461+
/// This method can be overridden to, for example, be less conservative and
462+
/// propagate the information even if some call sites are unknown.
463+
virtual LogicalResult
464+
visitCallableOperation(Operation *op, CallableOpInterface callable,
465+
ArrayRef<AbstractSparseLattice *> operandLattices);
466+
433467
private:
434468
/// Recursively initialize the analysis on nested operations and blocks.
435469
LogicalResult initializeRecursively(Operation *op);

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
128128
operandLattices.push_back(operandLattice);
129129
}
130130

131-
if (auto call = dyn_cast<CallOpInterface>(op)) {
132-
// If the call operation is to an external function, attempt to infer the
133-
// results from the call arguments.
134-
auto callable =
135-
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
136-
if (!getSolverConfig().isInterprocedural() ||
137-
(callable && !callable.getCallableRegion())) {
138-
visitExternalCallImpl(call, operandLattices, resultLattices);
139-
return success();
140-
}
141-
142-
// Otherwise, the results of a call operation are determined by the
143-
// callgraph.
144-
const auto *predecessors = getOrCreateFor<PredecessorState>(
145-
getProgramPointAfter(op), getProgramPointAfter(call));
146-
// If not all return sites are known, then conservatively assume we can't
147-
// reason about the data-flow.
148-
if (!predecessors->allPredecessorsKnown()) {
149-
setAllToEntryStates(resultLattices);
150-
return success();
151-
}
152-
for (Operation *predecessor : predecessors->getKnownPredecessors())
153-
for (auto &&[operand, resLattice] :
154-
llvm::zip(predecessor->getOperands(), resultLattices))
155-
join(resLattice,
156-
*getLatticeElementFor(getProgramPointAfter(op), operand));
157-
return success();
158-
}
131+
if (auto call = dyn_cast<CallOpInterface>(op))
132+
return visitCallOperation(call, operandLattices, resultLattices);
159133

160134
// Invoke the operation transfer function.
161135
return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
183157
if (block->isEntryBlock()) {
184158
// Check if this block is the entry block of a callable region.
185159
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
186-
if (callable && callable.getCallableRegion() == block->getParent()) {
187-
const auto *callsites = getOrCreateFor<PredecessorState>(
188-
getProgramPointBefore(block), getProgramPointAfter(callable));
189-
// If not all callsites are known, conservatively mark all lattices as
190-
// having reached their pessimistic fixpoints.
191-
if (!callsites->allPredecessorsKnown() ||
192-
!getSolverConfig().isInterprocedural()) {
193-
return setAllToEntryStates(argLattices);
194-
}
195-
for (Operation *callsite : callsites->getKnownPredecessors()) {
196-
auto call = cast<CallOpInterface>(callsite);
197-
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
198-
join(std::get<1>(it),
199-
*getLatticeElementFor(getProgramPointBefore(block),
200-
std::get<0>(it)));
201-
}
202-
return;
203-
}
160+
if (callable && callable.getCallableRegion() == block->getParent())
161+
return visitCallableOperation(callable, argLattices);
204162

205163
// Check if the lattices can be determined from region control flow.
206164
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
248206
}
249207
}
250208

209+
LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
210+
CallOpInterface call,
211+
ArrayRef<const AbstractSparseLattice *> operandLattices,
212+
ArrayRef<AbstractSparseLattice *> resultLattices) {
213+
// If the call operation is to an external function, attempt to infer the
214+
// results from the call arguments.
215+
auto callable =
216+
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
217+
if (!getSolverConfig().isInterprocedural() ||
218+
(callable && !callable.getCallableRegion())) {
219+
visitExternalCallImpl(call, operandLattices, resultLattices);
220+
return success();
221+
}
222+
223+
// Otherwise, the results of a call operation are determined by the
224+
// callgraph.
225+
const auto *predecessors = getOrCreateFor<PredecessorState>(
226+
getProgramPointAfter(call), getProgramPointAfter(call));
227+
// If not all return sites are known, then conservatively assume we can't
228+
// reason about the data-flow.
229+
if (!predecessors->allPredecessorsKnown()) {
230+
setAllToEntryStates(resultLattices);
231+
return success();
232+
}
233+
for (Operation *predecessor : predecessors->getKnownPredecessors())
234+
for (auto &&[operand, resLattice] :
235+
llvm::zip(predecessor->getOperands(), resultLattices))
236+
join(resLattice,
237+
*getLatticeElementFor(getProgramPointAfter(call), operand));
238+
return success();
239+
}
240+
241+
void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
242+
CallableOpInterface callable,
243+
ArrayRef<AbstractSparseLattice *> argLattices) {
244+
Block *entryBlock = &callable.getCallableRegion()->front();
245+
const auto *callsites = getOrCreateFor<PredecessorState>(
246+
getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
247+
// If not all callsites are known, conservatively mark all lattices as
248+
// having reached their pessimistic fixpoints.
249+
if (!callsites->allPredecessorsKnown() ||
250+
!getSolverConfig().isInterprocedural()) {
251+
return setAllToEntryStates(argLattices);
252+
}
253+
for (Operation *callsite : callsites->getKnownPredecessors()) {
254+
auto call = cast<CallOpInterface>(callsite);
255+
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
256+
join(std::get<1>(it),
257+
*getLatticeElementFor(getProgramPointBefore(entryBlock),
258+
std::get<0>(it)));
259+
}
260+
}
261+
251262
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
252263
ProgramPoint *point, RegionBranchOpInterface branch,
253264
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
512523
if (op->hasTrait<OpTrait::ReturnLike>()) {
513524
// Going backwards, the operands of the return are derived from the
514525
// results of all CallOps calling this CallableOp.
515-
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
516-
const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
517-
getProgramPointAfter(op), getProgramPointAfter(callable));
518-
if (callsites->allPredecessorsKnown()) {
519-
for (Operation *call : callsites->getKnownPredecessors()) {
520-
SmallVector<const AbstractSparseLattice *> callResultLattices =
521-
getLatticeElementsFor(getProgramPointAfter(op),
522-
call->getResults());
523-
for (auto [op, result] :
524-
llvm::zip(operandLattices, callResultLattices))
525-
meet(op, *result);
526-
}
527-
} else {
528-
// If we don't know all the callers, we can't know where the
529-
// returned values go. Note that, in particular, this will trigger
530-
// for the return ops of any public functions.
531-
setAllToExitStates(operandLattices);
532-
}
533-
return success();
534-
}
526+
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
527+
return visitCallableOperation(op, callable, operandLattices);
535528
}
536529

537530
return visitOperationImpl(op, operandLattices, resultLattices);
538531
}
539532

533+
LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
534+
Operation *op, CallableOpInterface callable,
535+
ArrayRef<AbstractSparseLattice *> operandLattices) {
536+
const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
537+
getProgramPointAfter(op), getProgramPointAfter(callable));
538+
if (callsites->allPredecessorsKnown()) {
539+
for (Operation *call : callsites->getKnownPredecessors()) {
540+
SmallVector<const AbstractSparseLattice *> callResultLattices =
541+
getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
542+
for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
543+
meet(op, *result);
544+
}
545+
} else {
546+
// If we don't know all the callers, we can't know where the
547+
// returned values go. Note that, in particular, this will trigger
548+
// for the return ops of any public functions.
549+
setAllToExitStates(operandLattices);
550+
}
551+
return success();
552+
}
553+
540554
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
541555
RegionBranchOpInterface branch,
542556
ArrayRef<AbstractSparseLattice *> operandLattices) {

0 commit comments

Comments
 (0)