@@ -55,7 +55,6 @@ def __init__(self, ws: WorkspaceClient, admin_locator: AdministratorLocator):
55
55
self ._admin_locator = admin_locator
56
56
self ._last_node_id = 0
57
57
self ._nodes : dict [tuple [str , str ], MigrationNode ] = {}
58
- self ._incoming : dict [tuple [str , str ], set [tuple [str , str ]]] = defaultdict (set )
59
58
self ._outgoing : dict [tuple [str , str ], set [tuple [str , str ]]] = defaultdict (set )
60
59
61
60
def register_workflow_task (self , task : jobs .Task , job : jobs .Job , _graph : DependencyGraph ) -> MigrationNode :
@@ -73,15 +72,12 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
73
72
object_owner = job_node .object_owner , # no task owner so use job one
74
73
)
75
74
self ._nodes [task_node .key ] = task_node
76
- self ._incoming [job_node .key ].add (task_node .key )
77
75
self ._outgoing [task_node .key ].add (job_node .key )
78
76
if task .existing_cluster_id :
79
77
cluster_node = self .register_cluster (task .existing_cluster_id )
80
78
if cluster_node :
81
- self ._incoming [cluster_node .key ].add (task_node .key )
82
79
self ._outgoing [task_node .key ].add (cluster_node .key )
83
80
# also make the cluster dependent on the job
84
- self ._incoming [cluster_node .key ].add (job_node .key )
85
81
self ._outgoing [job_node .key ].add (cluster_node .key )
86
82
# TODO register dependency graph
87
83
return task_node
@@ -104,7 +100,6 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
104
100
for job_cluster in job .settings .job_clusters :
105
101
cluster_node = self .register_job_cluster (job_cluster )
106
102
if cluster_node :
107
- self ._incoming [cluster_node .key ].add (job_node .key )
108
103
self ._outgoing [job_node .key ].add (cluster_node .key )
109
104
return job_node
110
105
@@ -132,31 +127,47 @@ def register_cluster(self, cluster_id: str) -> MigrationNode:
132
127
return cluster_node
133
128
134
129
def generate_steps (self ) -> Iterable [MigrationStep ]:
135
- # algo adapted from Kahn topological sort. The main differences is that
136
- # we want the same step number for all nodes with same dependency depth
137
- # so instead of pushing to a queue, we rebuild it once all leaf nodes are processed
138
- # (these are transient leaf nodes i.e. they only become leaf during processing)
139
- incoming_counts = self ._populate_incoming_counts ()
130
+ """algo adapted from Kahn topological sort. The differences are as follows:
131
+ - we want the same step number for all nodes with same dependency depth
132
+ so instead of pushing to a queue, we rebuild it once all leaf nodes are processed
133
+ (these are transient leaf nodes i.e. they only become leaf during processing)
134
+ - the inputs do not form a DAG so we need specialized handling of edge cases
135
+ (implemented in PR #3009)
136
+ """
137
+ # pre-compute incoming keys for best performance of self._required_step_ids
138
+ incoming_keys = self ._collect_incoming_keys ()
139
+ incoming_counts = self .compute_incoming_counts (incoming_keys )
140
140
step_number = 1
141
141
sorted_steps : list [MigrationStep ] = []
142
142
while len (incoming_counts ) > 0 :
143
143
leaf_keys = list (self ._get_leaf_keys (incoming_counts ))
144
144
for leaf_key in leaf_keys :
145
145
del incoming_counts [leaf_key ]
146
- sorted_steps .append (self ._nodes [leaf_key ].as_step (step_number , list (self ._required_step_ids (leaf_key ))))
146
+ sorted_steps .append (
147
+ self ._nodes [leaf_key ].as_step (step_number , list (self ._required_step_ids (incoming_keys [leaf_key ])))
148
+ )
147
149
for dependency_key in self ._outgoing [leaf_key ]:
148
150
incoming_counts [dependency_key ] -= 1
149
151
step_number += 1
150
152
return sorted_steps
151
153
152
- def _required_step_ids (self , node_key : tuple [str , str ]) -> Iterable [int ]:
153
- for leaf_key in self ._incoming [node_key ]:
154
- yield self ._nodes [leaf_key ].node_id
154
+ def _collect_incoming_keys (self ) -> dict [tuple [str , str ], set [tuple [str , str ]]]:
155
+ result : dict [tuple [str , str ], set [tuple [str , str ]]] = defaultdict (set )
156
+ for source , outgoing in self ._outgoing .items ():
157
+ for target in outgoing :
158
+ result [target ].add (source )
159
+ return result
160
+
161
+ def _required_step_ids (self , required_step_keys : set [tuple [str , str ]]) -> Iterable [int ]:
162
+ for source_key in required_step_keys :
163
+ yield self ._nodes [source_key ].node_id
155
164
156
- def _populate_incoming_counts (self ) -> dict [tuple [str , str ], int ]:
165
+ def compute_incoming_counts (
166
+ self , incoming : dict [tuple [str , str ], set [tuple [str , str ]]]
167
+ ) -> dict [tuple [str , str ], int ]:
157
168
result = defaultdict (int )
158
169
for node_key in self ._nodes :
159
- result [node_key ] = len (self . _incoming [node_key ])
170
+ result [node_key ] = len (incoming [node_key ])
160
171
return result
161
172
162
173
@staticmethod
0 commit comments