Skip to content

Commit 27cb2da

Browse files
committed
Fix PriorityQueue type hints
1 parent 86dba4c commit 27cb2da

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/databricks/labs/ucx/assessment/sequencing.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from collections import defaultdict
55
from collections.abc import Iterable
66
from dataclasses import dataclass, field
7-
from typing import Generic, TypeVar
87

98
from databricks.sdk import WorkspaceClient
109
from databricks.sdk.service import jobs
@@ -78,11 +77,13 @@ def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationSt
7877
)
7978

8079

81-
QueueTask = TypeVar("QueueTask")
82-
QueueEntry = list[int, int, QueueTask | str] # type: ignore
80+
# We expect `tuple[int, int, MigrationNode | str]`
81+
# for `[priority, counter, MigrationNode | PriorityQueue._REMOVED | PriorityQueue_UPDATED]`
82+
# but we use list for the required mutability
83+
QueueEntry = list[int | MigrationNode | str]
8384

8485

85-
class PriorityQueue(Generic[QueueTask]):
86+
class PriorityQueue:
8687
"""A priority queue supporting to update tasks.
8788
8889
An adaption from class:queue.Priority to support updating tasks.
@@ -100,38 +101,39 @@ class PriorityQueue(Generic[QueueTask]):
100101

101102
def __init__(self):
102103
self._entries: list[QueueEntry] = []
103-
self._entry_finder: dict[QueueTask, QueueEntry] = {}
104+
self._entry_finder: dict[MigrationNode, QueueEntry] = {}
104105
self._counter = 0 # Tiebreaker with equal priorities, then "first in, first out"
105106

106-
def put(self, priority: int, task: QueueTask) -> None:
107+
def put(self, priority: int, task: MigrationNode) -> None:
107108
"""Put or update task in the queue.
108109
109110
The lowest priority is retrieved from the queue first.
110111
"""
111112
if task in self._entry_finder:
112113
raise KeyError(f"Use `:meth:update` to update existing task: {task}")
113-
entry = [priority, self._counter, task]
114+
entry: QueueEntry = [priority, self._counter, task]
114115
self._entry_finder[task] = entry
115116
heapq.heappush(self._entries, entry)
116117
self._counter += 1
117118

118-
def get(self) -> QueueTask | None:
119+
def get(self) -> MigrationNode | None:
119120
"""Gets the tasks with lowest priority."""
120121
while self._entries:
121122
_, _, task = heapq.heappop(self._entries)
122123
if task in (self._REMOVED, self._UPDATED):
123124
continue
125+
assert isinstance(task, MigrationNode)
124126
self._remove(task)
125127
# Ignore type because heappop returns Any, while we know it is an QueueEntry
126-
return task # type: ignore
128+
return task
127129
return None
128130

129-
def _remove(self, task: QueueTask) -> None:
131+
def _remove(self, task: MigrationNode) -> None:
130132
"""Remove a task from the queue."""
131133
entry = self._entry_finder.pop(task)
132134
entry[2] = self._REMOVED
133135

134-
def update(self, priority: int, task: QueueTask) -> None:
136+
def update(self, priority: int, task: MigrationNode) -> None:
135137
"""Update a task in the queue."""
136138
entry = self._entry_finder.pop(task)
137139
if entry is None:
@@ -272,7 +274,7 @@ def _create_node_queue(self, incoming_counts: dict[MigrationNodeKey, int]) -> Pr
272274
A lower number means it is pulled from the queue first, i.e. the key with the lowest number of keys is retrieved
273275
first.
274276
"""
275-
priority_queue: PriorityQueue[MigrationNode] = PriorityQueue()
277+
priority_queue = PriorityQueue()
276278
for node_key, incoming_count in incoming_counts.items():
277279
priority_queue.put(incoming_count, self._nodes[node_key])
278280
return priority_queue

0 commit comments

Comments
 (0)