Skip to content

Commit b13c4d9

Browse files
nkvuongnfx
andauthored
Refactor WorkspaceInstaller using service factory (#1480)
## Changes - refactor `WorkspaceInstaller` using service factory. - this removes the need for `sql_backend_factory` and `wheel_builder_factory` - there is no more logic inside of `main` function in `install.py`, so we can achieve better test coverage through unit testing - remove obsolete reference to `WorkspaceInstaller` in `unit/test_dashboard.py` ### Linked issues <!-- DOC: Link issue with a keyword: close, closes, closed, fix, fixes, fixed, resolve, resolves, resolved. See https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword --> This follows #1209 ### Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [x] manually tested - [x] added unit tests - [x] verified on staging environment (screenshot attached) --------- Co-authored-by: Serge Smertin <259697+nfx@users.noreply.github.com>
1 parent 58c44df commit b13c4d9

File tree

5 files changed

+222
-149
lines changed

5 files changed

+222
-149
lines changed

src/databricks/labs/ucx/install.py

Lines changed: 79 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import re
66
import time
77
import webbrowser
8-
from collections.abc import Callable
98
from datetime import timedelta
9+
from functools import cached_property
1010
from typing import Any
1111

1212
import databricks.sdk.errors
@@ -19,7 +19,6 @@
1919
from databricks.labs.blueprint.wheels import (
2020
ProductInfo,
2121
Version,
22-
WheelsV2,
2322
find_project_root,
2423
)
2524
from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend
@@ -40,7 +39,7 @@
4039
from databricks.labs.ucx.assessment.jobs import JobInfo, SubmitRunInfo
4140
from databricks.labs.ucx.assessment.pipelines import PipelineInfo
4241
from databricks.labs.ucx.config import WorkspaceConfig
43-
from databricks.labs.ucx.contexts.cli_command import AccountContext
42+
from databricks.labs.ucx.contexts.cli_command import AccountContext, WorkspaceContext
4443
from databricks.labs.ucx.framework.dashboards import DashboardFromFiles
4544
from databricks.labs.ucx.framework.tasks import Task
4645
from databricks.labs.ucx.hive_metastore.grants import Grant
@@ -112,68 +111,72 @@ def extract_major_minor(version_string):
112111
return None
113112

114113

115-
class WorkspaceInstaller:
114+
class WorkspaceInstaller(WorkspaceContext):
115+
116116
def __init__(
117117
self,
118-
prompts: Prompts,
119-
installation: Installation,
120118
ws: WorkspaceClient,
121-
product_info: ProductInfo,
122119
environ: dict[str, str] | None = None,
123120
tasks: list[Task] | None = None,
124121
):
122+
super().__init__(ws)
125123
if not environ:
126124
environ = dict(os.environ.items())
125+
self._force_install = environ.get("UCX_FORCE_INSTALL")
127126
if "DATABRICKS_RUNTIME_VERSION" in environ:
128127
msg = "WorkspaceInstaller is not supposed to be executed in Databricks Runtime"
129128
raise SystemExit(msg)
130-
self._ws = ws
131-
self._installation = installation
132-
self._prompts = prompts
133-
self._policy_installer = ClusterPolicyInstaller(installation, ws, prompts)
134-
self._product_info = product_info
135-
self._force_install = environ.get("UCX_FORCE_INSTALL")
136-
self._is_account_install = environ.get("UCX_FORCE_INSTALL") == "account"
129+
130+
self._is_account_install = self._force_install == "account"
137131
self._tasks = tasks if tasks else Workflows.all().tasks()
138132

133+
@cached_property
134+
def upgrades(self):
135+
return Upgrades(self.product_info, self.installation)
136+
137+
@cached_property
138+
def policy_installer(self):
139+
return ClusterPolicyInstaller(self.installation, self.workspace_client, self.prompts)
140+
141+
@cached_property
142+
def installation(self):
143+
try:
144+
return self.product_info.current_installation(self.workspace_client)
145+
except NotFound:
146+
if self._force_install == "user":
147+
return Installation.assume_user_home(self.workspace_client, self.product_info.product_name())
148+
return Installation.assume_global(self.workspace_client, self.product_info.product_name())
149+
139150
def run(
140151
self,
141152
default_config: WorkspaceConfig | None = None,
142153
verify_timeout=timedelta(minutes=2),
143-
sql_backend_factory: Callable[[WorkspaceConfig], SqlBackend] | None = None,
144-
wheel_builder_factory: Callable[[], WheelsV2] | None = None,
145154
config: WorkspaceConfig | None = None,
146155
) -> WorkspaceConfig:
147-
logger.info(f"Installing UCX v{self._product_info.version()}")
156+
logger.info(f"Installing UCX v{self.product_info.version()}")
148157
if config is None:
149158
config = self.configure(default_config)
150-
if not sql_backend_factory:
151-
sql_backend_factory = self._new_sql_backend
152-
if not wheel_builder_factory:
153-
wheel_builder_factory = self._new_wheel_builder
154-
wheels = wheel_builder_factory()
155-
install_state = InstallState.from_installation(self._installation)
156159
if self._is_testing():
157160
return config
158161
workflows_deployment = WorkflowsDeployment(
159162
config,
160-
self._installation,
161-
install_state,
162-
self._ws,
163-
wheels,
164-
self._product_info,
163+
self.installation,
164+
self.install_state,
165+
self.workspace_client,
166+
self.wheels,
167+
self.product_info,
165168
verify_timeout,
166169
self._tasks,
167170
)
168171
workspace_installation = WorkspaceInstallation(
169172
config,
170-
self._installation,
171-
install_state,
172-
sql_backend_factory(config),
173-
self._ws,
173+
self.installation,
174+
self.install_state,
175+
self.sql_backend,
176+
self.workspace_client,
174177
workflows_deployment,
175-
self._prompts,
176-
self._product_info,
178+
self.prompts,
179+
self.product_info,
177180
)
178181
try:
179182
workspace_installation.run()
@@ -184,21 +187,21 @@ def run(
184187
return config
185188

186189
def _is_testing(self):
187-
return self._product_info.product_name() != "ucx"
190+
return self.product_info.product_name() != "ucx"
188191

189192
def _prompt_for_new_installation(self) -> WorkspaceConfig:
190193
logger.info("Please answer a couple of questions to configure Unity Catalog migration")
191-
inventory_database = self._prompts.question(
194+
inventory_database = self.prompts.question(
192195
"Inventory Database stored in hive_metastore", default="ucx", valid_regex=r"^\w+$"
193196
)
194-
log_level = self._prompts.question("Log level", default="INFO").upper()
195-
num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))
196-
configure_groups = ConfigureGroups(self._prompts)
197+
log_level = self.prompts.question("Log level", default="INFO").upper()
198+
num_threads = int(self.prompts.question("Number of threads", default="8", valid_number=True))
199+
configure_groups = ConfigureGroups(self.prompts)
197200
configure_groups.run()
198201
# Check if terraform is being used
199-
is_terraform_used = self._prompts.confirm("Do you use Terraform to deploy your infrastructure?")
202+
is_terraform_used = self.prompts.confirm("Do you use Terraform to deploy your infrastructure?")
200203
include_databases = self._select_databases()
201-
trigger_job = self._prompts.confirm("Do you want to trigger assessment job after installation?")
204+
trigger_job = self.prompts.confirm("Do you want to trigger assessment job after installation?")
202205
return WorkspaceConfig(
203206
inventory_database=inventory_database,
204207
workspace_group_regex=configure_groups.workspace_group_regex,
@@ -216,45 +219,41 @@ def _prompt_for_new_installation(self) -> WorkspaceConfig:
216219

217220
def _compare_remote_local_versions(self):
218221
try:
219-
local_version = self._product_info.released_version()
220-
remote_version = self._installation.load(Version).version
222+
local_version = self.product_info.released_version()
223+
remote_version = self.installation.load(Version).version
221224
if extract_major_minor(remote_version) == extract_major_minor(local_version):
222-
logger.info(f"UCX v{self._product_info.version()} is already installed on this workspace")
225+
logger.info(f"UCX v{self.product_info.version()} is already installed on this workspace")
223226
msg = "Do you want to update the existing installation?"
224-
if not self._prompts.confirm(msg):
227+
if not self.prompts.confirm(msg):
225228
raise RuntimeWarning(
226229
"UCX workspace remote and local install versions are same and no override is requested. Exiting..."
227230
)
228231
except NotFound as err:
229232
logger.warning(f"UCX workspace remote version not found: {err}")
230233

231-
def _new_wheel_builder(self):
232-
return WheelsV2(self._installation, self._product_info)
233-
234-
def _new_sql_backend(self, config: WorkspaceConfig) -> SqlBackend:
235-
return StatementExecutionBackend(self._ws, config.warehouse_id)
236-
237234
def _confirm_force_install(self) -> bool:
238235
if not self._force_install:
239236
return False
240237
msg = "[ADVANCED] UCX is already installed on this workspace. Do you want to create a new installation?"
241-
if not self._prompts.confirm(msg):
238+
if not self.prompts.confirm(msg):
242239
raise RuntimeWarning("UCX is already installed, but no confirmation")
243-
if not self._installation.is_global() and self._force_install == "global":
240+
if not self.installation.is_global() and self._force_install == "global":
244241
# TODO:
245242
# Logic for forced global over user install
246243
# Migration logic will go here
247244
# verify complains without full path, asks to raise NotImplementedError builtin
248245
raise databricks.sdk.errors.NotImplemented("Migration needed. Not implemented yet.")
249-
if self._installation.is_global() and self._force_install == "user":
246+
if self.installation.is_global() and self._force_install == "user":
250247
# Logic for forced user install over global install
251-
self._installation = Installation.assume_user_home(self._ws, self._product_info.product_name())
248+
self.replace(
249+
installation=Installation.assume_user_home(self.workspace_client, self.product_info.product_name())
250+
)
252251
return True
253252
return False
254253

255254
def configure(self, default_config: WorkspaceConfig | None = None) -> WorkspaceConfig:
256255
try:
257-
config = self._installation.load(WorkspaceConfig)
256+
config = self.installation.load(WorkspaceConfig)
258257
self._compare_remote_local_versions()
259258
if self._confirm_force_install():
260259
return self._configure_new_installation(default_config)
@@ -269,27 +268,25 @@ def replace_config(self, **changes: Any):
269268
Persist the list of workspaces where UCX is successfully installed in the config
270269
"""
271270
try:
272-
config = self._installation.load(WorkspaceConfig)
271+
config = self.installation.load(WorkspaceConfig)
273272
new_config = dataclasses.replace(config, **changes)
274-
self._installation.save(new_config)
273+
self.installation.save(new_config)
275274
except (PermissionDenied, NotFound, ValueError):
276-
logger.warning(f"Failed to replace config for {self._ws.config.host}")
275+
logger.warning(f"Failed to replace config for {self.workspace_client.config.host}")
277276

278277
def _apply_upgrades(self):
279278
try:
280-
upgrades = Upgrades(self._product_info, self._installation)
281-
upgrades.apply(self._ws)
279+
self.upgrades.apply(self.workspace_client)
282280
except (InvalidParameterValue, NotFound) as err:
283281
logger.warning(f"Installed version is too old: {err}")
284282

285283
def _configure_new_installation(self, default_config: WorkspaceConfig | None = None) -> WorkspaceConfig:
286284
if default_config is None:
287285
default_config = self._prompt_for_new_installation()
288-
HiveMetastoreLineageEnabler(self._ws).apply(self._prompts, self._is_account_install)
286+
HiveMetastoreLineageEnabler(self.workspace_client).apply(self.prompts, self._is_account_install)
289287
self._check_inventory_database_exists(default_config.inventory_database)
290288
warehouse_id = self._configure_warehouse()
291-
292-
policy_id, instance_profile, spark_conf_dict, instance_pool_id = self._policy_installer.create(
289+
policy_id, instance_profile, spark_conf_dict, instance_pool_id = self.policy_installer.create(
293290
default_config.inventory_database
294291
)
295292

@@ -306,38 +303,38 @@ def _configure_new_installation(self, default_config: WorkspaceConfig | None = N
306303
policy_id=policy_id,
307304
instance_pool_id=instance_pool_id,
308305
)
309-
self._installation.save(config)
306+
self.installation.save(config)
310307
if self._is_account_install:
311308
return config
312-
ws_file_url = self._installation.workspace_link(config.__file__)
313-
if self._prompts.confirm(f"Open config file in the browser and continue installing? {ws_file_url}"):
309+
ws_file_url = self.installation.workspace_link(config.__file__)
310+
if self.prompts.confirm(f"Open config file in the browser and continue installing? {ws_file_url}"):
314311
webbrowser.open(ws_file_url)
315312
return config
316313

317314
def _config_table_migration(self, spark_conf_dict) -> tuple[int, int, dict]:
318315
# parallelism will not be needed if backlog is fixed in https://databricks.atlassian.net/browse/ES-975874
319316
if self._is_account_install:
320317
return 1, 10, spark_conf_dict
321-
parallelism = self._prompts.question(
318+
parallelism = self.prompts.question(
322319
"Parallelism for migrating dbfs root delta tables with deep clone", default="200", valid_number=True
323320
)
324321
if int(parallelism) > 200:
325322
spark_conf_dict.update({'spark.sql.sources.parallelPartitionDiscovery.parallelism': parallelism})
326323
# mix max workers for auto-scale migration job cluster
327324
min_workers = int(
328-
self._prompts.question(
325+
self.prompts.question(
329326
"Min workers for auto-scale job cluster for table migration", default="1", valid_number=True
330327
)
331328
)
332329
max_workers = int(
333-
self._prompts.question(
330+
self.prompts.question(
334331
"Max workers for auto-scale job cluster for table migration", default="10", valid_number=True
335332
)
336333
)
337334
return min_workers, max_workers, spark_conf_dict
338335

339336
def _select_databases(self):
340-
selected_databases = self._prompts.question(
337+
selected_databases = self.prompts.question(
341338
"Comma-separated list of databases to migrate. If not specified, we'll use all "
342339
"databases in hive_metastore",
343340
default="<ALL>",
@@ -352,17 +349,17 @@ def warehouse_type(_):
352349

353350
pro_warehouses = {"[Create new PRO SQL warehouse]": "create_new"} | {
354351
f"{_.name} ({_.id}, {warehouse_type(_)}, {_.state.value})": _.id
355-
for _ in self._ws.warehouses.list()
352+
for _ in self.workspace_client.warehouses.list()
356353
if _.warehouse_type == EndpointInfoWarehouseType.PRO
357354
}
358355
if self._is_account_install:
359356
warehouse_id = "create_new"
360357
else:
361-
warehouse_id = self._prompts.choice_from_dict(
358+
warehouse_id = self.prompts.choice_from_dict(
362359
"Select PRO or SERVERLESS SQL warehouse to run assessment dashboards on", pro_warehouses
363360
)
364361
if warehouse_id == "create_new":
365-
new_warehouse = self._ws.warehouses.create(
362+
new_warehouse = self.workspace_client.warehouses.create(
366363
name=f"{WAREHOUSE_PREFIX} {time.time_ns()}",
367364
spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED,
368365
warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
@@ -374,7 +371,7 @@ def warehouse_type(_):
374371

375372
def _check_inventory_database_exists(self, inventory_database: str):
376373
logger.info("Fetching installations...")
377-
for installation in Installation.existing(self._ws, self._product_info.product_name()):
374+
for installation in Installation.existing(self.workspace_client, self.product_info.product_name()):
378375
try:
379376
config = installation.load(WorkspaceConfig)
380377
if config.inventory_database == inventory_database:
@@ -614,19 +611,13 @@ def _get_accessible_workspaces(self):
614611
accessible_workspaces.append(workspace)
615612
return accessible_workspaces
616613

617-
def _get_installer(self, app: ProductInfo, workspace: Workspace) -> WorkspaceInstaller:
614+
def _get_installer(self, workspace: Workspace) -> WorkspaceInstaller:
618615
workspace_client = self.account_client.get_workspace_client(workspace)
619616
logger.info(f"Installing UCX on workspace {workspace.deployment_name}")
620-
try:
621-
current = app.current_installation(workspace_client)
622-
except NotFound:
623-
current = Installation.assume_global(workspace_client, app.product_name())
624-
return WorkspaceInstaller(self.prompts, current, workspace_client, app)
617+
return WorkspaceInstaller(workspace_client).replace(product_info=self.product_info, prompts=self.prompts)
625618

626-
def install_on_account(self, app: ProductInfo | None = None):
619+
def install_on_account(self):
627620
ctx = AccountContext(self._get_safe_account_client())
628-
if app is None:
629-
app = ProductInfo.from_class(WorkspaceConfig)
630621
default_config = None
631622
confirmed = False
632623
accessible_workspaces = self._get_accessible_workspaces()
@@ -639,7 +630,7 @@ def install_on_account(self, app: ProductInfo | None = None):
639630

640631
for workspace in accessible_workspaces:
641632
logger.info(f"Installing UCX on workspace {workspace.deployment_name}")
642-
installer = self._get_installer(app, workspace)
633+
installer = self._get_installer(workspace)
643634
if not confirmed:
644635
default_config = None
645636
try:
@@ -656,26 +647,13 @@ def install_on_account(self, app: ProductInfo | None = None):
656647

657648
installed_workspace_ids = [w.workspace_id for w in installed_workspaces if w.workspace_id is not None]
658649
for workspace in installed_workspaces:
659-
installer = self._get_installer(app, workspace)
650+
installer = self._get_installer(workspace)
660651
installer.replace_config(installed_workspace_ids=installed_workspace_ids)
661652

662653
# upload the json dump of workspace info in the .ucx folder
663654
ctx.account_workspaces.sync_workspace_info(installed_workspaces)
664655

665656

666-
def install_on_workspace(app: ProductInfo | None = None):
667-
if app is None:
668-
app = ProductInfo.from_class(WorkspaceConfig)
669-
prompts = Prompts()
670-
workspace_client = WorkspaceClient(product="ucx", product_version=__version__)
671-
try:
672-
current = app.current_installation(workspace_client)
673-
except NotFound:
674-
current = Installation.assume_global(workspace_client, app.product_name())
675-
installer = WorkspaceInstaller(prompts, current, workspace_client, app)
676-
installer.run()
677-
678-
679657
if __name__ == "__main__":
680658
logger = get_logger(__file__)
681659

@@ -685,4 +663,5 @@ def install_on_workspace(app: ProductInfo | None = None):
685663
account_installer = AccountInstaller(AccountClient(product="ucx", product_version=__version__))
686664
account_installer.install_on_account()
687665
else:
688-
install_on_workspace()
666+
workspace_installer = WorkspaceInstaller(WorkspaceClient(product="ucx", product_version=__version__))
667+
workspace_installer.run()

0 commit comments

Comments
 (0)