30
30
import java .util .concurrent .ConcurrentHashMap ;
31
31
import java .util .stream .Collectors ;
32
32
33
- import org .sing_group .compi .core .loops .ForeachIteration ;
34
33
import org .sing_group .compi .core .loops .ForeachIterationDependency ;
35
34
import org .sing_group .compi .core .pipeline .Foreach ;
36
35
import org .sing_group .compi .core .pipeline .Task ;
37
36
38
37
public class TasksDAG {
39
38
40
- private final Map <Task , Set <Dependency <?>>> dag = new ConcurrentHashMap <>();
39
+ private final Map <Task , Set <Dependency <?>>> dag = new ConcurrentHashMap <>(); // task
40
+ // ->
41
+ // dependants
42
+ private final Map <Task , Set <Dependency <?>>> reverseDag = new ConcurrentHashMap <>(); // task
43
+ // ->
44
+ // dependencies
41
45
42
46
private Map <Task , Set <Dependency <?>>> dependantsCache = new HashMap <>();
43
47
private Map <Task , Set <Dependency <?>>> dependenciesCache = new HashMap <>();
44
48
45
49
public Set <Dependency <?>> getDependantsOfTask (Task t ) {
50
+
46
51
if (dependantsCache .containsKey (t ))
47
52
return dependantsCache .get (t );
48
53
@@ -55,8 +60,7 @@ public Set<Dependency<?>> getDependantsOfTask(Task t) {
55
60
dependantsOfTask .add (dependency );
56
61
Set <Dependency <?>> dependants = getDependantsOfTask (dependency .getDependantTask ());
57
62
dependantsOfTask .addAll (
58
- dependants .stream ().map (d -> new Dependency <Task >(t , d .getDependantTask ())
59
- ).collect (Collectors .toSet ())
63
+ dependants .stream ().map (d -> new Dependency <Task >(t , d .getDependantTask ())).collect (Collectors .toSet ())
60
64
);
61
65
}
62
66
@@ -73,6 +77,7 @@ public Set<Dependency<?>> getDependantsOfTask(Task t) {
73
77
* @return Tasks that task depends on
74
78
*/
75
79
public Set <Dependency <?>> getDependenciesOfTask (Task task ) {
80
+
76
81
if (dependenciesCache .containsKey (task ))
77
82
return dependenciesCache .get (task );
78
83
@@ -112,14 +117,16 @@ public void initializeTaskDependencies(Collection<Task> tasks) {
112
117
}
113
118
114
119
public boolean dependenciesAreMet (final Task task ) {
115
- return this .getDependenciesOfTask (task ).stream ()
116
- .map (Dependency ::getOnTask ).filter (t -> !t .isFinished ())
117
- .collect (Collectors .toList ())
120
+ if (!reverseDag .containsKey (task )) {
121
+ return true ;
122
+ }
123
+ return this .reverseDag .get (task ).stream ().filter (d -> !d .getOnTask ().isFinished ()).collect (Collectors .toList ())
118
124
.size () == 0 ;
119
125
}
120
126
121
127
public void removeDependency (Task t , Task dependant ) {
122
128
this .dag .get (t ).removeIf (d -> d .getDependantTask ().equals (dependant ));
129
+ this .reverseDag .get (dependant ).removeIf (d -> d .getOnTask ().equals (t ));
123
130
clearCache ();
124
131
}
125
132
@@ -130,9 +137,15 @@ public void addDependency(Task task, Task dependant, boolean isIterationDependen
130
137
if (!this .dag .containsKey (task )) {
131
138
this .dag .put (task , new HashSet <>());
132
139
}
133
- this .dag .get (task ).add (
134
- isIterationDependency ? new ForeachIterationDependency ((Foreach ) task , (Foreach ) dependant ) : new Dependency <Task >(task , dependant )
135
- );
140
+ if (!this .reverseDag .containsKey (dependant )) {
141
+ this .reverseDag .put (dependant , new HashSet <>());
142
+ }
143
+
144
+ Dependency <?> dependency =
145
+ isIterationDependency ? new ForeachIterationDependency ((Foreach ) task , (Foreach ) dependant ) : new Dependency <Task >(task , dependant );
146
+ this .dag .get (task ).add (dependency );
147
+ this .reverseDag .get (dependant ).add (dependency );
148
+
136
149
clearCache ();
137
150
}
138
151
0 commit comments