Skip to content

Commit a3a491f

Browse files
committed
Make types more consistent
1 parent ce801f3 commit a3a491f

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

src/databricks/labs/ucx/assessment/sequencing.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationSt
7979

8080

8181
QueueTask = TypeVar("QueueTask")
82-
QueueEntry = list[int, int, QueueTask | str]
82+
QueueEntry = list[int, int, QueueTask | str] # type: ignore
8383

8484

8585
class PriorityQueue:
@@ -153,7 +153,7 @@ def __init__(self, ws: WorkspaceClient, admin_locator: AdministratorLocator):
153153
self._admin_locator = admin_locator
154154
self._last_node_id = 0
155155
self._nodes: dict[MigrationNodeKey, MigrationNode] = {}
156-
self._outgoing: dict[MigrationNodeKey, set[MigrationNodeKey]] = defaultdict(set)
156+
self._outgoing: dict[MigrationNodeKey, set[MigrationNode]] = defaultdict(set)
157157

158158
def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: DependencyGraph) -> MigrationNode:
159159
task_id = f"{job.job_id}/{task.task_key}"
@@ -170,13 +170,13 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
170170
object_owner=job_node.object_owner, # no task owner so use job one
171171
)
172172
self._nodes[task_node.key] = task_node
173-
self._outgoing[task_node.key].add(job_node.key)
173+
self._outgoing[task_node.key].add(job_node)
174174
if task.existing_cluster_id:
175175
cluster_node = self.register_cluster(task.existing_cluster_id)
176176
if cluster_node:
177-
self._outgoing[task_node.key].add(cluster_node.key)
177+
self._outgoing[task_node.key].add(cluster_node)
178178
# also make the cluster dependent on the job
179-
self._outgoing[job_node.key].add(cluster_node.key)
179+
self._outgoing[job_node.key].add(cluster_node)
180180
# TODO register dependency graph
181181
return task_node
182182

@@ -198,7 +198,7 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
198198
for job_cluster in job.settings.job_clusters:
199199
cluster_node = self.register_job_cluster(job_cluster)
200200
if cluster_node:
201-
self._outgoing[job_node.key].add(cluster_node.key)
201+
self._outgoing[job_node.key].add(cluster_node)
202202
return job_node
203203

204204
def register_job_cluster(self, cluster: jobs.JobCluster) -> MigrationNode | None:
@@ -234,40 +234,35 @@ def generate_steps(self) -> Iterable[MigrationStep]:
234234
- We handle cyclic dependencies (implemented in PR #3009)
235235
"""
236236
# pre-compute incoming keys for best performance of self._required_step_ids
237-
incoming_keys = self._collect_incoming_keys()
238-
incoming_counts = self.compute_incoming_counts(incoming_keys)
237+
incoming = self._invert_outgoing_to_incoming()
238+
incoming_counts = self._compute_incoming_counts(incoming)
239239
key_queue = self._create_key_queue(incoming_counts)
240240
key = key_queue.get()
241241
step_number = 1
242242
sorted_steps: list[MigrationStep] = []
243243
while key is not None:
244-
required_step_ids = sorted(self._get_required_step_ids(incoming_keys[key]))
245-
step = self._nodes[key].as_step(step_number, required_step_ids)
244+
step = self._nodes[key].as_step(step_number, sorted(n.node_id for n in incoming[key]))
246245
sorted_steps.append(step)
247246
# Update queue priorities
248-
for dependency_key in self._outgoing[key]:
249-
incoming_counts[dependency_key] -= 1
250-
key_queue.update(incoming_counts[dependency_key], dependency_key)
247+
for dependency in self._outgoing[key]:
248+
incoming_counts[dependency.key] -= 1
249+
key_queue.update(incoming_counts[dependency.key], dependency)
251250
step_number += 1
252251
key = key_queue.get()
253252
return sorted_steps
254253

255-
def _collect_incoming_keys(self) -> dict[tuple[str, str], set[tuple[str, str]]]:
256-
result: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
257-
for source, outgoing in self._outgoing.items():
258-
for target in outgoing:
259-
result[target].add(source)
254+
def _invert_outgoing_to_incoming(self) -> dict[MigrationNodeKey, set[MigrationNode]]:
255+
result: dict[MigrationNodeKey, set[MigrationNode]] = defaultdict(set)
256+
for node_key, outgoing_nodes in self._outgoing.items():
257+
for target in outgoing_nodes:
258+
result[target.key].add(self._nodes[node_key])
260259
return result
261260

262-
def _get_required_step_ids(self, required_step_keys: set[tuple[str, str]]) -> Iterable[int]:
263-
for source_key in required_step_keys:
264-
yield self._nodes[source_key].node_id
265-
266-
def compute_incoming_counts(
267-
self, incoming: dict[tuple[str, str], set[tuple[str, str]]]
268-
) -> dict[tuple[str, str], int]:
261+
def _compute_incoming_counts(
262+
self, incoming: dict[MigrationNodeKey, set[MigrationNode]]
263+
) -> dict[MigrationNodeKey, int]:
269264
result = defaultdict(int)
270-
for node_key in self._nodes:
265+
for node_key in self._nodes.keys():
271266
result[node_key] = len(incoming[node_key])
272267
return result
273268

0 commit comments

Comments (0)