@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
128
128
operandLattices.push_back (operandLattice);
129
129
}
130
130
131
- if (auto call = dyn_cast<CallOpInterface>(op)) {
132
- // If the call operation is to an external function, attempt to infer the
133
- // results from the call arguments.
134
- auto callable =
135
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable ());
136
- if (!getSolverConfig ().isInterprocedural () ||
137
- (callable && !callable.getCallableRegion ())) {
138
- visitExternalCallImpl (call, operandLattices, resultLattices);
139
- return success ();
140
- }
141
-
142
- // Otherwise, the results of a call operation are determined by the
143
- // callgraph.
144
- const auto *predecessors = getOrCreateFor<PredecessorState>(
145
- getProgramPointAfter (op), getProgramPointAfter (call));
146
- // If not all return sites are known, then conservatively assume we can't
147
- // reason about the data-flow.
148
- if (!predecessors->allPredecessorsKnown ()) {
149
- setAllToEntryStates (resultLattices);
150
- return success ();
151
- }
152
- for (Operation *predecessor : predecessors->getKnownPredecessors ())
153
- for (auto &&[operand, resLattice] :
154
- llvm::zip (predecessor->getOperands (), resultLattices))
155
- join (resLattice,
156
- *getLatticeElementFor (getProgramPointAfter (op), operand));
157
- return success ();
158
- }
131
+ if (auto call = dyn_cast<CallOpInterface>(op))
132
+ return visitCallOperation (call, operandLattices, resultLattices);
159
133
160
134
// Invoke the operation transfer function.
161
135
return visitOperationImpl (op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
183
157
if (block->isEntryBlock ()) {
184
158
// Check if this block is the entry block of a callable region.
185
159
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp ());
186
- if (callable && callable.getCallableRegion () == block->getParent ()) {
187
- const auto *callsites = getOrCreateFor<PredecessorState>(
188
- getProgramPointBefore (block), getProgramPointAfter (callable));
189
- // If not all callsites are known, conservatively mark all lattices as
190
- // having reached their pessimistic fixpoints.
191
- if (!callsites->allPredecessorsKnown () ||
192
- !getSolverConfig ().isInterprocedural ()) {
193
- return setAllToEntryStates (argLattices);
194
- }
195
- for (Operation *callsite : callsites->getKnownPredecessors ()) {
196
- auto call = cast<CallOpInterface>(callsite);
197
- for (auto it : llvm::zip (call.getArgOperands (), argLattices))
198
- join (std::get<1 >(it),
199
- *getLatticeElementFor (getProgramPointBefore (block),
200
- std::get<0 >(it)));
201
- }
202
- return ;
203
- }
160
+ if (callable && callable.getCallableRegion () == block->getParent ())
161
+ return visitCallableOperation (callable, argLattices);
204
162
205
163
// Check if the lattices can be determined from region control flow.
206
164
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp ())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
248
206
}
249
207
}
250
208
209
+ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation (
210
+ CallOpInterface call,
211
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
212
+ ArrayRef<AbstractSparseLattice *> resultLattices) {
213
+ // If the call operation is to an external function, attempt to infer the
214
+ // results from the call arguments.
215
+ auto callable =
216
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable ());
217
+ if (!getSolverConfig ().isInterprocedural () ||
218
+ (callable && !callable.getCallableRegion ())) {
219
+ visitExternalCallImpl (call, operandLattices, resultLattices);
220
+ return success ();
221
+ }
222
+
223
+ // Otherwise, the results of a call operation are determined by the
224
+ // callgraph.
225
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
226
+ getProgramPointAfter (call), getProgramPointAfter (call));
227
+ // If not all return sites are known, then conservatively assume we can't
228
+ // reason about the data-flow.
229
+ if (!predecessors->allPredecessorsKnown ()) {
230
+ setAllToEntryStates (resultLattices);
231
+ return success ();
232
+ }
233
+ for (Operation *predecessor : predecessors->getKnownPredecessors ())
234
+ for (auto &&[operand, resLattice] :
235
+ llvm::zip (predecessor->getOperands (), resultLattices))
236
+ join (resLattice,
237
+ *getLatticeElementFor (getProgramPointAfter (call), operand));
238
+ return success ();
239
+ }
240
+
241
+ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation (
242
+ CallableOpInterface callable,
243
+ ArrayRef<AbstractSparseLattice *> argLattices) {
244
+ Block *entryBlock = &callable.getCallableRegion ()->front ();
245
+ const auto *callsites = getOrCreateFor<PredecessorState>(
246
+ getProgramPointBefore (entryBlock), getProgramPointAfter (callable));
247
+ // If not all callsites are known, conservatively mark all lattices as
248
+ // having reached their pessimistic fixpoints.
249
+ if (!callsites->allPredecessorsKnown () ||
250
+ !getSolverConfig ().isInterprocedural ()) {
251
+ return setAllToEntryStates (argLattices);
252
+ }
253
+ for (Operation *callsite : callsites->getKnownPredecessors ()) {
254
+ auto call = cast<CallOpInterface>(callsite);
255
+ for (auto it : llvm::zip (call.getArgOperands (), argLattices))
256
+ join (std::get<1 >(it),
257
+ *getLatticeElementFor (getProgramPointBefore (entryBlock),
258
+ std::get<0 >(it)));
259
+ }
260
+ }
261
+
251
262
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors (
252
263
ProgramPoint *point, RegionBranchOpInterface branch,
253
264
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
512
523
if (op->hasTrait <OpTrait::ReturnLike>()) {
513
524
// Going backwards, the operands of the return are derived from the
514
525
// results of all CallOps calling this CallableOp.
515
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp ())) {
516
- const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
517
- getProgramPointAfter (op), getProgramPointAfter (callable));
518
- if (callsites->allPredecessorsKnown ()) {
519
- for (Operation *call : callsites->getKnownPredecessors ()) {
520
- SmallVector<const AbstractSparseLattice *> callResultLattices =
521
- getLatticeElementsFor (getProgramPointAfter (op),
522
- call->getResults ());
523
- for (auto [op, result] :
524
- llvm::zip (operandLattices, callResultLattices))
525
- meet (op, *result);
526
- }
527
- } else {
528
- // If we don't know all the callers, we can't know where the
529
- // returned values go. Note that, in particular, this will trigger
530
- // for the return ops of any public functions.
531
- setAllToExitStates (operandLattices);
532
- }
533
- return success ();
534
- }
526
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp ()))
527
+ return visitCallableOperation (op, callable, operandLattices);
535
528
}
536
529
537
530
return visitOperationImpl (op, operandLattices, resultLattices);
538
531
}
539
532
533
+ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation (
534
+ Operation *op, CallableOpInterface callable,
535
+ ArrayRef<AbstractSparseLattice *> operandLattices) {
536
+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
537
+ getProgramPointAfter (op), getProgramPointAfter (callable));
538
+ if (callsites->allPredecessorsKnown ()) {
539
+ for (Operation *call : callsites->getKnownPredecessors ()) {
540
+ SmallVector<const AbstractSparseLattice *> callResultLattices =
541
+ getLatticeElementsFor (getProgramPointAfter (op), call->getResults ());
542
+ for (auto [op, result] : llvm::zip (operandLattices, callResultLattices))
543
+ meet (op, *result);
544
+ }
545
+ } else {
546
+ // If we don't know all the callers, we can't know where the
547
+ // returned values go. Note that, in particular, this will trigger
548
+ // for the return ops of any public functions.
549
+ setAllToExitStates (operandLattices);
550
+ }
551
+ return success ();
552
+ }
553
+
540
554
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors (
541
555
RegionBranchOpInterface branch,
542
556
ArrayRef<AbstractSparseLattice *> operandLattices) {
0 commit comments