Skip to content

Commit da0330d

Browse files
committed
make 'incoming' transient and improve comments
1 parent 17a33e3 commit da0330d

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

src/databricks/labs/ucx/assessment/sequencing.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def __init__(self, ws: WorkspaceClient, admin_locator: AdministratorLocator):
5555
self._admin_locator = admin_locator
5656
self._last_node_id = 0
5757
self._nodes: dict[tuple[str, str], MigrationNode] = {}
58-
self._incoming: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
5958
self._outgoing: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
6059

6160
def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: DependencyGraph) -> MigrationNode:
@@ -73,15 +72,12 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
7372
object_owner=job_node.object_owner, # no task owner so use job one
7473
)
7574
self._nodes[task_node.key] = task_node
76-
self._incoming[job_node.key].add(task_node.key)
7775
self._outgoing[task_node.key].add(job_node.key)
7876
if task.existing_cluster_id:
7977
cluster_node = self.register_cluster(task.existing_cluster_id)
8078
if cluster_node:
81-
self._incoming[cluster_node.key].add(task_node.key)
8279
self._outgoing[task_node.key].add(cluster_node.key)
8380
# also make the cluster dependent on the job
84-
self._incoming[cluster_node.key].add(job_node.key)
8581
self._outgoing[job_node.key].add(cluster_node.key)
8682
# TODO register dependency graph
8783
return task_node
@@ -104,7 +100,6 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
104100
for job_cluster in job.settings.job_clusters:
105101
cluster_node = self.register_job_cluster(job_cluster)
106102
if cluster_node:
107-
self._incoming[cluster_node.key].add(job_node.key)
108103
self._outgoing[job_node.key].add(cluster_node.key)
109104
return job_node
110105

@@ -132,31 +127,47 @@ def register_cluster(self, cluster_id: str) -> MigrationNode:
132127
return cluster_node
133128

134129
def generate_steps(self) -> Iterable[MigrationStep]:
135-
# algo adapted from Kahn topological sort. The main difference is that
136-
# we want the same step number for all nodes with same dependency depth
137-
# so instead of pushing to a queue, we rebuild it once all leaf nodes are processed
138-
# (these are transient leaf nodes i.e. they only become leaves during processing)
139-
incoming_counts = self._populate_incoming_counts()
130+
"""algo adapted from Kahn topological sort. The differences are as follows:
131+
- we want the same step number for all nodes with same dependency depth
132+
so instead of pushing to a queue, we rebuild it once all leaf nodes are processed
133+
(these are transient leaf nodes i.e. they only become leaves during processing)
134+
- the inputs do not form a DAG so we need specialized handling of edge cases
135+
(implemented in PR #3009)
136+
"""
137+
# pre-compute incoming keys for best performance of self._required_step_ids
138+
incoming_keys = self._collect_incoming_keys()
139+
incoming_counts = self.compute_incoming_counts(incoming_keys)
140140
step_number = 1
141141
sorted_steps: list[MigrationStep] = []
142142
while len(incoming_counts) > 0:
143143
leaf_keys = list(self._get_leaf_keys(incoming_counts))
144144
for leaf_key in leaf_keys:
145145
del incoming_counts[leaf_key]
146-
sorted_steps.append(self._nodes[leaf_key].as_step(step_number, list(self._required_step_ids(leaf_key))))
146+
sorted_steps.append(
147+
self._nodes[leaf_key].as_step(step_number, list(self._required_step_ids(incoming_keys[leaf_key])))
148+
)
147149
for dependency_key in self._outgoing[leaf_key]:
148150
incoming_counts[dependency_key] -= 1
149151
step_number += 1
150152
return sorted_steps
151153

152-
def _required_step_ids(self, node_key: tuple[str, str]) -> Iterable[int]:
153-
for leaf_key in self._incoming[node_key]:
154-
yield self._nodes[leaf_key].node_id
154+
def _collect_incoming_keys(self) -> dict[tuple[str, str], set[tuple[str, str]]]:
155+
result: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
156+
for source, outgoing in self._outgoing.items():
157+
for target in outgoing:
158+
result[target].add(source)
159+
return result
160+
161+
def _required_step_ids(self, required_step_keys: set[tuple[str, str]]) -> Iterable[int]:
162+
for source_key in required_step_keys:
163+
yield self._nodes[source_key].node_id
155164

156-
def _populate_incoming_counts(self) -> dict[tuple[str, str], int]:
165+
def compute_incoming_counts(
166+
self, incoming: dict[tuple[str, str], set[tuple[str, str]]]
167+
) -> dict[tuple[str, str], int]:
157168
result = defaultdict(int)
158169
for node_key in self._nodes:
159-
result[node_key] = len(self._incoming[node_key])
170+
result[node_key] = len(incoming[node_key])
160171
return result
161172

162173
@staticmethod

0 commit comments

Comments
 (0)