@@ -234,18 +234,19 @@ def generate_steps(self) -> Iterable[MigrationStep]:
234
234
"""
235
235
# pre-compute incoming keys for best performance of self._required_step_ids
236
236
incoming = self ._invert_outgoing_to_incoming ()
237
- incoming_counts = self ._compute_incoming_counts (incoming )
238
- queue = self . _create_node_queue ( incoming_counts )
237
+ queue = self ._create_node_queue (incoming )
238
+ seen = set [ MigrationNode ]( )
239
239
node = queue .get ()
240
240
step_number = 1
241
241
ordered_steps : list [MigrationStep ] = []
242
242
while node is not None :
243
243
step = node .as_step (step_number , sorted (n .node_id for n in incoming [node .key ]))
244
244
ordered_steps .append (step )
245
- # Update queue priorities
245
+ seen .add (node )
246
+ # Update the queue priority as if the migration step was completed
246
247
for dependency in self ._outgoing [node .key ]:
247
- incoming_counts [dependency .key ] -= 1
248
- queue .update (incoming_counts [ dependency . key ] , dependency )
248
+ priority = len ( incoming [dependency .key ] - seen )
249
+ queue .update (priority , dependency )
249
250
step_number += 1
250
251
node = queue .get ()
251
252
return ordered_steps
@@ -257,21 +258,13 @@ def _invert_outgoing_to_incoming(self) -> dict[MigrationNodeKey, set[MigrationNo
257
258
result [target .key ].add (self ._nodes [node_key ])
258
259
return result
259
260
260
- def _compute_incoming_counts (
261
- self , incoming : dict [MigrationNodeKey , set [MigrationNode ]]
262
- ) -> dict [MigrationNodeKey , int ]:
263
- result = defaultdict (int )
264
- for node_key in self ._nodes .keys ():
265
- result [node_key ] = len (incoming [node_key ])
266
- return result
267
-
268
- def _create_node_queue (self , incoming_counts : dict [MigrationNodeKey , int ]) -> PriorityQueue :
261
+ def _create_node_queue (self , incoming : dict [MigrationNodeKey , set [MigrationNode ]]) -> PriorityQueue :
269
262
"""Create a priority queue for their nodes using the incoming count as priority.
270
263
271
264
A lower number means it is pulled from the queue first, i.e. the key with the lowest number of keys is retrieved
272
265
first.
273
266
"""
274
267
priority_queue = PriorityQueue ()
275
- for node_key , incoming_count in incoming_counts .items ():
276
- priority_queue .put (incoming_count , self ._nodes [node_key ])
268
+ for node_key , incoming_nodes in incoming .items ():
269
+ priority_queue .put (len ( incoming_nodes ) , self ._nodes [node_key ])
277
270
return priority_queue
0 commit comments