4
4
from collections import defaultdict
5
5
from collections .abc import Iterable
6
6
from dataclasses import dataclass , field
7
- from typing import Generic , TypeVar
8
7
9
8
from databricks .sdk import WorkspaceClient
10
9
from databricks .sdk .service import jobs
@@ -78,11 +77,13 @@ def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationSt
78
77
)
79
78
80
79
81
- QueueTask = TypeVar ("QueueTask" )
82
- QueueEntry = list [int , int , QueueTask | str ] # type: ignore
80
+ # We expect `tuple[int, int, MigrationNode | str]`
81
+ # for `[priority, counter, MigrationNode | PriorityQueue._REMOVED | PriorityQueue._UPDATED]`
82
+ # but we use list for the required mutability
83
+ QueueEntry = list [int | MigrationNode | str ]
83
84
84
85
85
- class PriorityQueue ( Generic [ QueueTask ]) :
86
+ class PriorityQueue :
86
87
"""A priority queue supporting to update tasks.
87
88
88
89
An adaptation of :class:`queue.PriorityQueue` that supports updating tasks.
@@ -100,38 +101,39 @@ class PriorityQueue(Generic[QueueTask]):
100
101
101
102
def __init__ (self ):
102
103
self ._entries : list [QueueEntry ] = []
103
- self ._entry_finder : dict [QueueTask , QueueEntry ] = {}
104
+ self ._entry_finder : dict [MigrationNode , QueueEntry ] = {}
104
105
self ._counter = 0 # Tiebreaker with equal priorities, then "first in, first out"
105
106
106
- def put (self , priority : int , task : QueueTask ) -> None :
107
+ def put (self , priority : int , task : MigrationNode ) -> None :
107
108
"""Put or update task in the queue.
108
109
109
110
The lowest priority is retrieved from the queue first.
110
111
"""
111
112
if task in self ._entry_finder :
112
113
raise KeyError (f"Use `:meth:update` to update existing task: { task } " )
113
- entry = [priority , self ._counter , task ]
114
+ entry : QueueEntry = [priority , self ._counter , task ]
114
115
self ._entry_finder [task ] = entry
115
116
heapq .heappush (self ._entries , entry )
116
117
self ._counter += 1
117
118
118
- def get (self ) -> QueueTask | None :
119
+ def get (self ) -> MigrationNode | None :
119
120
"""Gets the tasks with lowest priority."""
120
121
while self ._entries :
121
122
_ , _ , task = heapq .heappop (self ._entries )
122
123
if task in (self ._REMOVED , self ._UPDATED ):
123
124
continue
125
+ assert isinstance (task , MigrationNode )
124
126
self ._remove (task )
125
127
# Ignore type because heappop returns Any, while we know it is a QueueEntry
126
- return task # type: ignore
128
+ return task
127
129
return None
128
130
129
- def _remove (self , task : QueueTask ) -> None :
131
+ def _remove (self , task : MigrationNode ) -> None :
130
132
"""Remove a task from the queue."""
131
133
entry = self ._entry_finder .pop (task )
132
134
entry [2 ] = self ._REMOVED
133
135
134
- def update (self , priority : int , task : QueueTask ) -> None :
136
+ def update (self , priority : int , task : MigrationNode ) -> None :
135
137
"""Update a task in the queue."""
136
138
entry = self ._entry_finder .pop (task )
137
139
if entry is None :
@@ -272,7 +274,7 @@ def _create_node_queue(self, incoming_counts: dict[MigrationNodeKey, int]) -> Pr
272
274
A lower number means it is pulled from the queue first, i.e. the node with the lowest incoming count is retrieved
273
275
first.
274
276
"""
275
- priority_queue : PriorityQueue [ MigrationNode ] = PriorityQueue ()
277
+ priority_queue = PriorityQueue ()
276
278
for node_key , incoming_count in incoming_counts .items ():
277
279
priority_queue .put (incoming_count , self ._nodes [node_key ])
278
280
return priority_queue
0 commit comments