1
1
from __future__ import annotations
2
2
3
- import itertools
3
+ from collections import defaultdict
4
4
from collections .abc import Iterable
5
- from dataclasses import dataclass , field
5
+ from dataclasses import dataclass
6
6
7
7
from databricks .sdk import WorkspaceClient
8
8
from databricks .sdk .service import jobs
@@ -18,7 +18,7 @@ class MigrationStep:
18
18
object_id : str
19
19
object_name : str
20
20
object_owner : str
21
- required_step_ids : list [int ] = field ( default_factory = list )
21
+ required_step_ids : list [int ]
22
22
23
23
24
24
@dataclass
@@ -28,54 +28,35 @@ class MigrationNode:
28
28
object_id : str
29
29
object_name : str
30
30
object_owner : str
31
- required_steps : list [MigrationNode ] = field (default_factory = list )
32
-
33
- def generate_steps (self ) -> tuple [MigrationStep , Iterable [MigrationStep ]]:
34
- # traverse the nodes using a depth-first algorithm
35
- # ultimate leaves have a step number of 1
36
- # use highest required step number + 1 for this step
37
- highest_step_number = 0
38
- required_step_ids : list [int ] = []
39
- all_generated_steps : list [Iterable [MigrationStep ]] = []
40
- for required_step in self .required_steps :
41
- step , generated_steps = required_step .generate_steps ()
42
- highest_step_number = max (highest_step_number , step .step_number )
43
- required_step_ids .append (step .step_id )
44
- all_generated_steps .append (generated_steps )
45
- all_generated_steps .append ([step ])
46
- this_step = MigrationStep (
31
+
32
@property
def key(self) -> tuple[str, str]:
    """Stable lookup key for this node: the (object_type, object_id) pair."""
    return (self.object_type, self.object_id)
35
+
36
def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationStep:
    """Project this node into a MigrationStep at the given sequencing depth.

    The step carries this node's identity fields verbatim; step_number and
    required_step_ids are supplied by the caller (the topological sequencer).
    """
    step = MigrationStep(
        step_id=self.node_id,
        step_number=step_number,
        object_type=self.object_type,
        object_id=self.object_id,
        object_name=self.object_name,
        object_owner=self.object_owner,
        required_step_ids=required_step_ids,
    )
    return step
55
- return this_step , itertools .chain (* all_generated_steps )
56
-
57
- def find (self , object_type : str , object_id : str ) -> MigrationNode | None :
58
- if object_type == self .object_type and object_id == self .object_id :
59
- return self
60
- for step in self .required_steps :
61
- found = step .find (object_type , object_id )
62
- if found :
63
- return found
64
- return None
65
46
66
47
67
48
class MigrationSequencer :
68
49
69
50
def __init__ (self , ws : WorkspaceClient ):
70
51
self ._ws = ws
71
52
self ._last_node_id = 0
72
- self ._root = MigrationNode (
73
- node_id = 0 , object_type = "ROOT" , object_id = "ROOT" , object_name = "ROOT" , object_owner = "NONE"
74
- )
53
+ self ._nodes : dict [ tuple [ str , str ], MigrationNode ] = {}
54
+ self . _incoming : dict [ tuple [ str , str ], set [ tuple [ str , str ]]] = defaultdict ( set )
55
+ self . _outgoing : dict [ tuple [ str , str ], set [ tuple [ str , str ]]] = defaultdict ( set )
75
56
76
57
def register_workflow_task (self , task : jobs .Task , job : jobs .Job , _graph : DependencyGraph ) -> MigrationNode :
77
58
task_id = f"{ job .job_id } /{ task .task_key } "
78
- task_node = self ._find_node ( object_type = "TASK" , object_id = task_id )
59
+ task_node = self ._nodes . get (( "TASK" , task_id ), None )
79
60
if task_node :
80
61
return task_node
81
62
job_node = self .register_workflow_job (job )
@@ -87,17 +68,22 @@ def register_workflow_task(self, task: jobs.Task, job: jobs.Job, _graph: Depende
87
68
object_name = task .task_key ,
88
69
object_owner = job_node .object_owner , # no task owner so use job one
89
70
)
90
- job_node .required_steps .append (task_node )
71
+ self ._nodes [task_node .key ] = task_node
72
+ self ._incoming [job_node .key ].add (task_node .key )
73
+ self ._outgoing [task_node .key ].add (job_node .key )
91
74
if task .existing_cluster_id :
92
75
cluster_node = self .register_cluster (task .existing_cluster_id )
93
- cluster_node .required_steps .append (task_node )
94
- if job_node not in cluster_node .required_steps :
95
- cluster_node .required_steps .append (job_node )
76
+ if cluster_node :
77
+ self ._incoming [cluster_node .key ].add (task_node .key )
78
+ self ._outgoing [task_node .key ].add (cluster_node .key )
79
+ # also make the cluster dependent on the job
80
+ self ._incoming [cluster_node .key ].add (job_node .key )
81
+ self ._outgoing [job_node .key ].add (cluster_node .key )
96
82
# TODO register dependency graph
97
83
return task_node
98
84
99
85
def register_workflow_job (self , job : jobs .Job ) -> MigrationNode :
100
- job_node = self ._find_node ( object_type = "JOB" , object_id = str (job .job_id ))
86
+ job_node = self ._nodes . get (( "JOB" , str (job .job_id )), None )
101
87
if job_node :
102
88
return job_node
103
89
self ._last_node_id += 1
@@ -109,63 +95,69 @@ def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
109
95
object_name = job_name ,
110
96
object_owner = job .creator_user_name or "<UNKNOWN>" ,
111
97
)
112
- top_level = True
98
+ self . _nodes [ job_node . key ] = job_node
113
99
if job .settings and job .settings .job_clusters :
114
100
for job_cluster in job .settings .job_clusters :
115
101
cluster_node = self .register_job_cluster (job_cluster )
116
102
if cluster_node :
117
- top_level = False
118
- cluster_node .required_steps .append (job_node )
119
- if top_level :
120
- self ._root .required_steps .append (job_node )
103
+ self ._incoming [cluster_node .key ].add (job_node .key )
104
+ self ._outgoing [job_node .key ].add (cluster_node .key )
121
105
return job_node
122
106
123
107
def register_job_cluster(self, cluster: jobs.JobCluster) -> MigrationNode | None:
    """Register the cluster backing a job-cluster entry, or None for inline clusters.

    A `new_cluster` spec is defined inline on the job (presumably created per
    run — confirm), so there is no standalone cluster object to migrate.
    """
    return None if cluster.new_cluster else self.register_cluster(cluster.job_cluster_key)
127
111
128
- def register_cluster (self , cluster_key : str ) -> MigrationNode :
129
- cluster_node = self ._find_node ( object_type = "CLUSTER" , object_id = cluster_key )
112
def register_cluster(self, cluster_id: str) -> MigrationNode:
    """Return the existing CLUSTER node for cluster_id, creating and indexing it if needed.

    Looks up cluster details via the workspace client to resolve a display
    name and an owner, falling back to the id / "<UNKNOWN>" respectively.
    """
    cluster_node = self._nodes.get(("CLUSTER", cluster_id), None)
    if cluster_node:
        return cluster_node
    details = self._ws.clusters.get(cluster_id)
    object_name = details.cluster_name if details and details.cluster_name else cluster_id
    # BUG FIX: object_owner was referenced in the constructor below but its
    # assignment was dropped, causing a NameError on every cache miss.
    object_owner = details.creator_user_name if details and details.creator_user_name else "<UNKNOWN>"
    self._last_node_id += 1
    cluster_node = MigrationNode(
        node_id=self._last_node_id,
        object_type="CLUSTER",
        object_id=cluster_id,
        object_name=object_name,
        object_owner=object_owner,
    )
    self._nodes[cluster_node.key] = cluster_node
    # TODO register warehouses and policies
    return cluster_node
146
129
147
130
def generate_steps(self) -> Iterable[MigrationStep]:
    """Emit migration steps in dependency order (Kahn-style topological sort).

    Unlike the textbook queue-based Kahn algorithm, all nodes at the same
    dependency depth share one step number: each pass gathers every current
    leaf (zero unprocessed prerequisites), emits those with the current depth,
    decrements the counts of their dependents, then moves to the next depth.
    """
    pending = self._populate_incoming_counts()
    ordered: list[MigrationStep] = []
    depth = 1
    while pending:
        # snapshot the leaves before mutating the counts
        for leaf in list(self._get_leaf_keys(pending)):
            del pending[leaf]
            node = self._nodes[leaf]
            ordered.append(node.as_step(depth, list(self._required_step_ids(leaf))))
            # unblock everything that was waiting on this leaf
            for dependent in self._outgoing[leaf]:
                pending[dependent] -= 1
        depth += 1
    return ordered
147
+
148
+ def _required_step_ids (self , node_key : tuple [str , str ]) -> Iterable [int ]:
149
+ for leaf_key in self ._incoming [node_key ]:
150
+ yield self ._nodes [leaf_key ].node_id
151
+
152
+ def _populate_incoming_counts (self ) -> dict [tuple [str , str ], int ]:
153
+ result = defaultdict (int )
154
+ for node_key in self ._nodes :
155
+ result [node_key ] = len (self ._incoming [node_key ])
156
+ return result
151
157
152
158
@staticmethod
153
- def _sorted_steps (steps : Iterable [MigrationStep ]) -> Iterable [MigrationStep ]:
154
- # sort by step number, lowest first
155
- return sorted (steps , key = lambda step : step .step_number )
156
-
157
- @staticmethod
158
- def _deduplicate_steps (steps : Iterable [MigrationStep ]) -> Iterable [MigrationStep ]:
159
- best_steps : dict [int , MigrationStep ] = {}
160
- for step in steps :
161
- existing = best_steps .get (step .step_id , None )
162
- # keep the step with the highest step number
163
- # TODO this possibly affects the step_number of steps that depend on this one
164
- # but it's probably OK to not be 100% accurate initially
165
- if existing and existing .step_number >= step .step_number :
159
+ def _get_leaf_keys (incoming_counts : dict [tuple [str , str ], int ]) -> Iterable [tuple [str , str ]]:
160
+ for node_key , incoming_count in incoming_counts .items ():
161
+ if incoming_count > 0 :
166
162
continue
167
- best_steps [step .step_id ] = step
168
- return best_steps .values ()
169
-
170
- def _find_node (self , object_type : str , object_id : str ) -> MigrationNode | None :
171
- return self ._root .find (object_type , object_id )
163
+ yield node_key
0 commit comments