@@ -251,13 +251,64 @@ bool LinalgDependenceGraph::hasDependenceFrom(
251
251
return false ;
252
252
}
253
253
254
- bool LinalgDependenceGraph::hasDependentOperations (
254
+ bool LinalgDependenceGraph::hasDependentOperationsFrom (
255
+ LinalgOp linalgOp,
256
+ ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
257
+ for (auto dep : depTypes) {
258
+ if (!getDependencesFrom (linalgOp, dep).empty ())
259
+ return true ;
260
+ }
261
+ return false ;
262
+ }
263
+
264
+ bool LinalgDependenceGraph::hasDependentOperationsInto (
255
265
LinalgOp linalgOp,
256
266
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
257
267
for (auto dep : depTypes) {
258
- if (!getDependencesFrom (linalgOp, dep).empty () ||
259
- !getDependencesInto (linalgOp, dep).empty ())
268
+ if (!getDependencesInto (linalgOp, dep).empty ())
260
269
return true ;
261
270
}
262
271
return false ;
263
272
}
273
+
274
+ bool LinalgDependenceGraph::hasDependentOperations (
275
+ LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
276
+ return hasDependentOperationsInto (linalgOp, depTypes) ||
277
+ hasDependentOperationsFrom (linalgOp, depTypes);
278
+ }
279
+
280
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2 >
281
+ LinalgDependenceGraph::getDependentOperationsInto (
282
+ LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
283
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2 >
284
+ dependentOperations;
285
+ for (auto dependenceType : depTypes) {
286
+ auto dependencies = getDependencesInto (linalgOp, dependenceType);
287
+ dependentOperations.append (dependencies.begin (), dependencies.end ());
288
+ }
289
+ return dependentOperations;
290
+ }
291
+
292
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2 >
293
+ LinalgDependenceGraph::getDependentOperationsFrom (
294
+ LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
295
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2 >
296
+ dependentOperations;
297
+ for (auto dependenceType : depTypes) {
298
+ auto dependencies = getDependencesFrom (linalgOp, dependenceType);
299
+ dependentOperations.append (dependencies.begin (), dependencies.end ());
300
+ }
301
+ return dependentOperations;
302
+ }
303
+
304
+ // / Returns all dependent operations (into and from) given `operation`.
305
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2 >
306
+ LinalgDependenceGraph::getDependentOperations (
307
+ LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
308
+ SmallVector<LinalgDependenceGraphElem, 2 > dependentOperations =
309
+ getDependentOperationsInto (linalgOp, depTypes);
310
+ SmallVector<LinalgDependenceGraphElem, 2 > t =
311
+ getDependentOperationsFrom (linalgOp, depTypes);
312
+ dependentOperations.append (t.begin (), t.end ());
313
+ return dependentOperations;
314
+ }
0 commit comments