Skip to content

Commit 9c5d569

Browse files
committed
cherry-pick changes
1 parent f2ce384 commit 9c5d569

File tree

1 file changed

+64
-72
lines changed

1 file changed

+64
-72
lines changed
Lines changed: 64 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
import itertools
3+
from collections import defaultdict
44
from collections.abc import Iterable
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
66

77
from databricks.sdk import WorkspaceClient
88
from databricks.sdk.service import jobs
@@ -18,7 +18,7 @@ class MigrationStep:
1818
object_id: str
1919
object_name: str
2020
object_owner: str
21-
required_step_ids: list[int] = field(default_factory=list)
21+
required_step_ids: list[int]
2222

2323

2424
@dataclass
@@ -28,54 +28,35 @@ class MigrationNode:
2828
object_id: str
2929
object_name: str
3030
object_owner: str
31-
required_steps: list[MigrationNode] = field(default_factory=list)
32-
33-
def generate_steps(self) -> tuple[MigrationStep, Iterable[MigrationStep]]:
34-
# traverse the nodes using a depth-first algorithm
35-
# ultimate leaves have a step number of 1
36-
# use highest required step number + 1 for this step
37-
highest_step_number = 0
38-
required_step_ids: list[int] = []
39-
all_generated_steps: list[Iterable[MigrationStep]] = []
40-
for required_step in self.required_steps:
41-
step, generated_steps = required_step.generate_steps()
42-
highest_step_number = max(highest_step_number, step.step_number)
43-
required_step_ids.append(step.step_id)
44-
all_generated_steps.append(generated_steps)
45-
all_generated_steps.append([step])
46-
this_step = MigrationStep(
31+
32+
@property
def key(self) -> tuple[str, str]:
    """Composite lookup key for this node: ``(object_type, object_id)``."""
    return (self.object_type, self.object_id)
35+
36+
def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationStep:
    """Project this node into a MigrationStep.

    step_number: dependency depth shared by all nodes sequenced together.
    required_step_ids: step ids of this node's direct prerequisites.
    """
    step = MigrationStep(
        step_id=self.node_id,
        step_number=step_number,
        object_type=self.object_type,
        object_id=self.object_id,
        object_name=self.object_name,
        object_owner=self.object_owner,
        required_step_ids=required_step_ids,
    )
    return step
55-
return this_step, itertools.chain(*all_generated_steps)
56-
57-
def find(self, object_type: str, object_id: str) -> MigrationNode | None:
58-
if object_type == self.object_type and object_id == self.object_id:
59-
return self
60-
for step in self.required_steps:
61-
found = step.find(object_type, object_id)
62-
if found:
63-
return found
64-
return None
6546

6647

6748
class MigrationSequencer:
6849

6950
def __init__(self, ws: WorkspaceClient):
    """Track migration nodes and the dependency edges between them."""
    self._ws = ws
    # monotonically increasing id handed out to each newly registered node
    self._last_node_id = 0
    # node lookup keyed by (object_type, object_id)
    self._nodes: dict[tuple[str, str], MigrationNode] = {}
    # incoming[key] = keys that must be migrated before `key`;
    # outgoing[key] = keys that depend on `key` being migrated first
    self._incoming: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
    self._outgoing: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
7556

7657
def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: DependencyGraph) -> MigrationNode:
7758
task_id = f"{job.job_id}/{task.task_key}"
78-
task_node = self._find_node(object_type="TASK", object_id=task_id)
59+
task_node = self._nodes.get(("TASK", task_id), None)
7960
if task_node:
8061
return task_node
8162
job_node = self.register_workflow_job(job)
@@ -87,17 +68,22 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
8768
object_name=task.task_key,
8869
object_owner=job_node.object_owner, # no task owner so use job one
8970
)
90-
job_node.required_steps.append(task_node)
71+
self._nodes[task_node.key] = task_node
72+
self._incoming[job_node.key].add(task_node.key)
73+
self._outgoing[task_node.key].add(job_node.key)
9174
if task.existing_cluster_id:
9275
cluster_node = self.register_cluster(task.existing_cluster_id)
93-
cluster_node.required_steps.append(task_node)
94-
if job_node not in cluster_node.required_steps:
95-
cluster_node.required_steps.append(job_node)
76+
if cluster_node:
77+
self._incoming[cluster_node.key].add(task_node.key)
78+
self._outgoing[task_node.key].add(cluster_node.key)
79+
# also make the cluster dependent on the job
80+
self._incoming[cluster_node.key].add(job_node.key)
81+
self._outgoing[job_node.key].add(cluster_node.key)
9682
# TODO register dependency graph
9783
return task_node
9884

9985
def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
100-
job_node = self._find_node(object_type="JOB", object_id=str(job.job_id))
86+
job_node = self._nodes.get(("JOB", str(job.job_id)), None)
10187
if job_node:
10288
return job_node
10389
self._last_node_id += 1
@@ -109,63 +95,69 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
10995
object_name=job_name,
11096
object_owner=job.creator_user_name or "<UNKNOWN>",
11197
)
112-
top_level = True
98+
self._nodes[job_node.key] = job_node
11399
if job.settings and job.settings.job_clusters:
114100
for job_cluster in job.settings.job_clusters:
115101
cluster_node = self.register_job_cluster(job_cluster)
116102
if cluster_node:
117-
top_level = False
118-
cluster_node.required_steps.append(job_node)
119-
if top_level:
120-
self._root.required_steps.append(job_node)
103+
self._incoming[cluster_node.key].add(job_node.key)
104+
self._outgoing[job_node.key].add(cluster_node.key)
121105
return job_node
122106

123107
def register_job_cluster(self, cluster: jobs.JobCluster) -> MigrationNode | None:
    """Register the cluster referenced by a job-cluster spec.

    Specs carrying an inline ``new_cluster`` definition are skipped
    (returns None); otherwise the cluster is registered by its key.
    """
    return None if cluster.new_cluster else self.register_cluster(cluster.job_cluster_key)
127111

128-
def register_cluster(self, cluster_key: str) -> MigrationNode:
129-
cluster_node = self._find_node(object_type="CLUSTER", object_id=cluster_key)
112+
def register_cluster(self, cluster_id: str) -> MigrationNode:
    """Register a cluster node, fetching its name and owner from the workspace.

    Returns the existing node when the cluster was already registered.
    """
    cluster_node = self._nodes.get(("CLUSTER", cluster_id), None)
    if cluster_node:
        return cluster_node
    details = self._ws.clusters.get(cluster_id)
    object_name = details.cluster_name if details and details.cluster_name else cluster_id
    # BUG FIX: this assignment was dropped in the refactoring while
    # `object_owner` is still referenced below, causing a NameError.
    object_owner = details.creator_user_name if details and details.creator_user_name else "<UNKNOWN>"
    self._last_node_id += 1
    cluster_node = MigrationNode(
        node_id=self._last_node_id,
        object_type="CLUSTER",
        object_id=cluster_id,
        object_name=object_name,
        object_owner=object_owner,
    )
    self._nodes[cluster_node.key] = cluster_node
    # TODO register warehouses and policies
    return cluster_node
146129

147130
def generate_steps(self) -> Iterable[MigrationStep]:
    """Sequence all registered nodes into migration steps.

    Adapted Kahn topological sort: every node at the same dependency
    depth gets the same step number, so we emit the current crop of
    leaves in one batch and only then rescan for the next crop
    (nodes become leaves transiently, as their prerequisites drain).
    """
    incoming_counts = self._populate_incoming_counts()
    sorted_steps: list[MigrationStep] = []
    step_number = 1
    while incoming_counts:
        # materialize before mutating incoming_counts below
        for leaf_key in list(self._get_leaf_keys(incoming_counts)):
            del incoming_counts[leaf_key]
            node = self._nodes[leaf_key]
            prerequisites = list(self._required_step_ids(leaf_key))
            sorted_steps.append(node.as_step(step_number, prerequisites))
            for successor_key in self._outgoing[leaf_key]:
                incoming_counts[successor_key] -= 1
        step_number += 1
    return sorted_steps
147+
148+
def _required_step_ids(self, node_key: tuple[str, str]) -> Iterable[int]:
    """Yield the node ids of node_key's direct prerequisites."""
    yield from (self._nodes[prerequisite].node_id for prerequisite in self._incoming[node_key])
151+
152+
def _populate_incoming_counts(self) -> dict[tuple[str, str], int]:
    """Initial count of unresolved prerequisites for every registered node.

    Returned as a defaultdict(int) so callers may decrement keys
    without pre-checking membership.
    """
    return defaultdict(int, {node_key: len(self._incoming[node_key]) for node_key in self._nodes})
151157

152158
@staticmethod
def _get_leaf_keys(incoming_counts: dict[tuple[str, str], int]) -> Iterable[tuple[str, str]]:
    """Yield keys whose unresolved prerequisite count has reached zero."""
    for node_key, pending in incoming_counts.items():
        if pending <= 0:
            yield node_key

0 commit comments

Comments
 (0)