@@ -79,7 +79,7 @@ def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationSt
79
79
80
80
81
81
QueueTask = TypeVar ("QueueTask" )
82
- QueueEntry = list [int , int , QueueTask | str ]
82
+ QueueEntry = list [int , int , QueueTask | str ] # type: ignore
83
83
84
84
85
85
class PriorityQueue :
@@ -153,7 +153,7 @@ def __init__(self, ws: WorkspaceClient, admin_locator: AdministratorLocator):
153
153
self ._admin_locator = admin_locator
154
154
self ._last_node_id = 0
155
155
self ._nodes : dict [MigrationNodeKey , MigrationNode ] = {}
156
- self ._outgoing : dict [MigrationNodeKey , set [MigrationNodeKey ]] = defaultdict (set )
156
+ self ._outgoing : dict [MigrationNodeKey , set [MigrationNode ]] = defaultdict (set )
157
157
158
158
def register_workflow_task (self , task : jobs .Task , job : jobs .Job , _graph : DependencyGraph ) -> MigrationNode :
159
159
task_id = f"{ job .job_id } /{ task .task_key } "
@@ -170,13 +170,13 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
170
170
object_owner = job_node .object_owner , # no task owner so use job one
171
171
)
172
172
self ._nodes [task_node .key ] = task_node
173
- self ._outgoing [task_node .key ].add (job_node . key )
173
+ self ._outgoing [task_node .key ].add (job_node )
174
174
if task .existing_cluster_id :
175
175
cluster_node = self .register_cluster (task .existing_cluster_id )
176
176
if cluster_node :
177
- self ._outgoing [task_node .key ].add (cluster_node . key )
177
+ self ._outgoing [task_node .key ].add (cluster_node )
178
178
# also make the cluster dependent on the job
179
- self ._outgoing [job_node .key ].add (cluster_node . key )
179
+ self ._outgoing [job_node .key ].add (cluster_node )
180
180
# TODO register dependency graph
181
181
return task_node
182
182
@@ -198,7 +198,7 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
198
198
for job_cluster in job .settings .job_clusters :
199
199
cluster_node = self .register_job_cluster (job_cluster )
200
200
if cluster_node :
201
- self ._outgoing [job_node .key ].add (cluster_node . key )
201
+ self ._outgoing [job_node .key ].add (cluster_node )
202
202
return job_node
203
203
204
204
def register_job_cluster (self , cluster : jobs .JobCluster ) -> MigrationNode | None :
@@ -234,40 +234,35 @@ def generate_steps(self) -> Iterable[MigrationStep]:
234
234
- We handle cyclic dependencies (implemented in PR #3009)
235
235
"""
236
236
# pre-compute incoming keys for best performance of self._required_step_ids
237
- incoming_keys = self ._collect_incoming_keys ()
238
- incoming_counts = self .compute_incoming_counts ( incoming_keys )
237
+ incoming = self ._invert_outgoing_to_incoming ()
238
+ incoming_counts = self ._compute_incoming_counts ( incoming )
239
239
key_queue = self ._create_key_queue (incoming_counts )
240
240
key = key_queue .get ()
241
241
step_number = 1
242
242
sorted_steps : list [MigrationStep ] = []
243
243
while key is not None :
244
- required_step_ids = sorted (self ._get_required_step_ids (incoming_keys [key ]))
245
- step = self ._nodes [key ].as_step (step_number , required_step_ids )
244
+ step = self ._nodes [key ].as_step (step_number , sorted (n .node_id for n in incoming [key ]))
246
245
sorted_steps .append (step )
247
246
# Update queue priorities
248
- for dependency_key in self ._outgoing [key ]:
249
- incoming_counts [dependency_key ] -= 1
250
- key_queue .update (incoming_counts [dependency_key ], dependency_key )
247
+ for dependency in self ._outgoing [key ]:
248
+ incoming_counts [dependency . key ] -= 1
249
+ key_queue .update (incoming_counts [dependency . key ], dependency )
251
250
step_number += 1
252
251
key = key_queue .get ()
253
252
return sorted_steps
254
253
255
- def _collect_incoming_keys (self ) -> dict [tuple [ str , str ], set [tuple [ str , str ] ]]:
256
- result : dict [tuple [ str , str ], set [tuple [ str , str ] ]] = defaultdict (set )
257
- for source , outgoing in self ._outgoing .items ():
258
- for target in outgoing :
259
- result [target ].add (source )
254
+ def _invert_outgoing_to_incoming (self ) -> dict [MigrationNodeKey , set [MigrationNode ]]:
255
+ result : dict [MigrationNodeKey , set [MigrationNode ]] = defaultdict (set )
256
+ for node_key , outgoing_nodes in self ._outgoing .items ():
257
+ for target in outgoing_nodes :
258
+ result [target . key ].add (self . _nodes [ node_key ] )
260
259
return result
261
260
262
- def _get_required_step_ids (self , required_step_keys : set [tuple [str , str ]]) -> Iterable [int ]:
263
- for source_key in required_step_keys :
264
- yield self ._nodes [source_key ].node_id
265
-
266
- def compute_incoming_counts (
267
- self , incoming : dict [tuple [str , str ], set [tuple [str , str ]]]
268
- ) -> dict [tuple [str , str ], int ]:
261
+ def _compute_incoming_counts (
262
+ self , incoming : dict [MigrationNodeKey , set [MigrationNode ]]
263
+ ) -> dict [MigrationNodeKey , int ]:
269
264
result = defaultdict (int )
270
- for node_key in self ._nodes :
265
+ for node_key in self ._nodes . keys () :
271
266
result [node_key ] = len (incoming [node_key ])
272
267
return result
273
268
0 commit comments