diff --git a/labs.yml b/labs.yml index 82481aadf..0522a38c1 100644 --- a/labs.yml +++ b/labs.yml @@ -63,6 +63,9 @@ commands: - name: interactive description: (Optional) Whether installing in interactive mode (`true|false|auto`); configuration settings are prompted for when interactive default: auto + - name: include-llm-transpiler + description: (Optional) Whether to include LLM-based transpiler in installation (`true|false`) + default: "false" - name: describe-transpile description: Describe installed transpilers diff --git a/pyproject.toml b/pyproject.toml index f61e8abe3..525ab4d38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,9 @@ dependencies = [ "SQLAlchemy~=2.0.40", "pygls~=2.0.0a2", "duckdb~=1.2.2", + "databricks-switch-plugin~=0.1.4", # Temporary, until Switch is migrated to be a transpiler (LSP) plugin. "requests>=2.28.1,<3" # Matches databricks-sdk (and 'types-requests' below), to avoid conflicts. + ] [project.urls] diff --git a/src/databricks/labs/lakebridge/cli.py b/src/databricks/labs/lakebridge/cli.py index 8b2b09a61..30318393d 100644 --- a/src/databricks/labs/lakebridge/cli.py +++ b/src/databricks/labs/lakebridge/cli.py @@ -730,6 +730,7 @@ def install_transpile( w: WorkspaceClient, artifact: str | None = None, interactive: str | None = None, + include_llm_transpiler: bool = False, transpiler_repository: TranspilerRepository = TranspilerRepository.user_home(), ) -> None: """Install or upgrade the Lakebridge transpilers.""" @@ -738,9 +739,13 @@ def install_transpile( ctx.add_user_agent_extra("cmd", "install-transpile") if artifact: ctx.add_user_agent_extra("artifact-overload", Path(artifact).name) + if include_llm_transpiler: + ctx.add_user_agent_extra("include-llm-transpiler", "true") user = w.current_user logger.debug(f"User: {user}") - transpile_installer = installer(w, transpiler_repository, is_interactive=is_interactive) + transpile_installer = installer( + w, transpiler_repository, is_interactive=is_interactive, include_llm=include_llm_transpiler + ) transpile_installer.run(module="transpile", artifact=artifact) diff --git a/src/databricks/labs/lakebridge/config.py b/src/databricks/labs/lakebridge/config.py index 17d14e148..346b3f745 100644 --- a/src/databricks/labs/lakebridge/config.py +++ b/src/databricks/labs/lakebridge/config.py @@ -274,5 +274,7 @@ class ReconcileConfig: @dataclass class LakebridgeConfiguration: - transpile: TranspileConfig | None = None - reconcile: ReconcileConfig | None = None + transpile: TranspileConfig | None + reconcile: ReconcileConfig | None + # Temporary flag, indicating whether to include the LLM-based Switch transpiler. 
+ include_switch: bool = False diff --git a/src/databricks/labs/lakebridge/contexts/application.py b/src/databricks/labs/lakebridge/contexts/application.py index f9e0875d8..b95be98de 100644 --- a/src/databricks/labs/lakebridge/contexts/application.py +++ b/src/databricks/labs/lakebridge/contexts/application.py @@ -18,6 +18,7 @@ from databricks.labs.lakebridge.deployment.dashboard import DashboardDeployment from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation from databricks.labs.lakebridge.deployment.recon import TableDeployment, JobDeployment, ReconDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment from databricks.labs.lakebridge.helpers.metastore import CatalogOperations logger = logging.getLogger(__name__) @@ -119,6 +120,16 @@ def recon_deployment(self) -> ReconDeployment: self.dashboard_deployment, ) + @cached_property + def switch_deployment(self) -> SwitchDeployment: + return SwitchDeployment( + self.workspace_client, + self.installation, + self.install_state, + self.product_info, + self.job_deployment, + ) + @cached_property def workspace_installation(self) -> WorkspaceInstallation: return WorkspaceInstallation( @@ -126,6 +137,7 @@ def workspace_installation(self) -> WorkspaceInstallation: self.prompts, self.installation, self.recon_deployment, + self.switch_deployment, self.product_info, self.upgrades, ) diff --git a/src/databricks/labs/lakebridge/deployment/configurator.py b/src/databricks/labs/lakebridge/deployment/configurator.py index 191937fb9..3281d98f6 100644 --- a/src/databricks/labs/lakebridge/deployment/configurator.py +++ b/src/databricks/labs/lakebridge/deployment/configurator.py @@ -1,14 +1,18 @@ import logging import time +from collections.abc import Iterator + from databricks.labs.blueprint.tui import Prompts from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import DatabricksError from databricks.sdk.service.catalog import Privilege, SecurableType from databricks.sdk.service.sql import ( CreateWarehouseRequestWarehouseType, EndpointInfoWarehouseType, SpotInstancePolicy, ) +from databricks.sdk.service.serving import ServingEndpoint from databricks.labs.lakebridge.helpers.metastore import CatalogOperations @@ -29,8 +33,9 @@ def __init__(self, ws: WorkspaceClient, prompts: Prompts, catalog_ops: CatalogOp def prompt_for_catalog_setup( self, + default_catalog_name: str = "remorph", ) -> str: - catalog_name = self._prompts.question("Enter catalog name", default="remorph") + catalog_name = self._prompts.question("Enter catalog name", default=default_catalog_name) catalog = self._catalog_ops.get_catalog(catalog_name) if catalog: logger.info(f"Found existing catalog `{catalog_name}`") @@ -103,6 +108,35 @@ def warehouse_type(_): raise SystemExit("Cannot continue installation, without a valid warehouse. Aborting the installation.") return warehouse_id + def prompt_for_foundation_model_choice(self, default_choice: str = "databricks-claude-sonnet-4-5") -> str: + """ + List Serving Endpoints that expose a foundation model and prompt the user to pick one. 
+        Returns the selected endpoint name.
+        """
+        endpoints: Iterator[ServingEndpoint] = self._ws.serving_endpoints.list()
+
+        model_endpoints = [
+            ep
+            for ep in endpoints
+            if ep.name
+            and ep.config
+            and ep.config.served_entities
+            and any(getattr(se, "foundation_model", None) is not None for se in ep.config.served_entities)
+        ]
+
+        foundational_model_names = [ep.name for ep in model_endpoints if ep.name]
+
+        if not foundational_model_names:
+            raise DatabricksError("No Foundation Model serving endpoints found. Aborting the installation.")
+        # Ensure the default choice always appears first in the list.
+        other_models = sorted(set(foundational_model_names) - {default_choice})
+        choices = [f"[DEFAULT] {default_choice}", *other_models]
+        selected = self._prompts.choice("Select a Foundation Model serving endpoint:", choices, sort=True)
+
+        if selected.startswith("[DEFAULT]"):
+            selected = default_choice
+        return selected
+
     def has_necessary_catalog_access(
         self, catalog_name: str, user_name: str, privilege_sets: tuple[set[Privilege], ...]
     ):
diff --git a/src/databricks/labs/lakebridge/deployment/installation.py b/src/databricks/labs/lakebridge/deployment/installation.py
index 7ff283f0e..852d11121 100644
--- a/src/databricks/labs/lakebridge/deployment/installation.py
+++ b/src/databricks/labs/lakebridge/deployment/installation.py
@@ -13,6 +13,7 @@
 from databricks.labs.lakebridge.config import LakebridgeConfiguration
 from databricks.labs.lakebridge.deployment.recon import ReconDeployment
+from databricks.labs.lakebridge.deployment.switch import SwitchDeployment
 
 logger = logging.getLogger("databricks.labs.lakebridge.install")
 
@@ -24,6 +25,7 @@ def __init__(
         prompts: Prompts,
         installation: Installation,
         recon_deployment: ReconDeployment,
+        switch_deployment: SwitchDeployment,
         product_info: ProductInfo,
         upgrades: Upgrades,
     ):
@@ -31,6 +33,7 @@
         self._prompts = prompts
         self._installation = installation
         self._recon_deployment = recon_deployment
+        self._switch_deployment = switch_deployment
         self._product_info = product_info
         self._upgrades = upgrades
 
@@ -96,6 +99,9 @@ def install(self, config: LakebridgeConfiguration):
         if config.reconcile:
             logger.info("Installing Lakebridge reconcile Metadata components.")
             self._recon_deployment.install(config.reconcile, wheel_path)
+        if config.include_switch:
+            logger.info("Installing Switch transpiler to workspace.")
+            self._switch_deployment.install()
 
     def uninstall(self, config: LakebridgeConfiguration):
         # This will remove all the Lakebridge modules
@@ -116,9 +122,14 @@
                     f"Won't remove transpile validation schema `{config.transpile.schema_name}` "
                     f"from catalog `{config.transpile.catalog_name}`. Please remove it manually."
                )
+        self._uninstall_switch_job()
         if config.reconcile:
             self._recon_deployment.uninstall(config.reconcile)
 
         self._installation.remove()
         logger.info("Uninstallation completed successfully.")
+
+    def _uninstall_switch_job(self) -> None:
+        """Remove the Switch transpiler job if it exists."""
+        self._switch_deployment.uninstall()
diff --git a/src/databricks/labs/lakebridge/deployment/switch.py b/src/databricks/labs/lakebridge/deployment/switch.py
new file mode 100644
index 000000000..fc948435c
--- /dev/null
+++ b/src/databricks/labs/lakebridge/deployment/switch.py
@@ -0,0 +1,182 @@
+import importlib.resources
+import logging
+from collections.abc import Generator, Sequence
+from importlib.abc import Traversable
+from pathlib import PurePosixPath
+from typing import Any
+
+from databricks.labs import switch
+from databricks.labs.switch.__about__ import __version__ as switch_version
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.blueprint.paths import WorkspacePath
+from databricks.labs.blueprint.wheels import ProductInfo
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import InvalidParameterValue, NotFound
+from databricks.sdk.service.jobs import JobParameterDefinition, JobSettings, NotebookTask, Source, Task
+
+from databricks.labs.lakebridge.deployment.job import JobDeployment
+
+logger = logging.getLogger(__name__)
+
+
+class SwitchDeployment:
+    _INSTALL_STATE_KEY = "Switch"
+    _TRANSPILER_ID = "switch"
+
+    def __init__(
+        self,
+        ws: WorkspaceClient,
+        installation: Installation,
+        install_state: InstallState,
+        product_info: ProductInfo,
+        job_deployer: JobDeployment,
+    ):
+        self._ws = ws
+        self._installation = installation
+        self._install_state = install_state
+        self._product_info = product_info
+        self._job_deployer = job_deployer
+
+    def install(self) -> None:
+        """Deploy Switch to workspace and configure resources."""
+        logger.debug("Deploying Switch resources to workspace...")
+        self._deploy_resources_to_workspace()
+        self._setup_job()
+        logger.debug("Switch deployment completed")
+
+    def uninstall(self) -> None:
+        """Remove the Switch job from the workspace."""
+        if self._INSTALL_STATE_KEY not in self._install_state.jobs:
+            logger.debug("No Switch job found in InstallState")
+            return
+
+        job_id = int(self._install_state.jobs[self._INSTALL_STATE_KEY])
+        try:
+            logger.info(f"Removing Switch job with job_id={job_id}")
+            del self._install_state.jobs[self._INSTALL_STATE_KEY]
+            self._ws.jobs.delete(job_id)
+            self._install_state.save()
+        except (InvalidParameterValue, NotFound):
+            logger.warning(f"Switch job {job_id} doesn't exist anymore")
+            # Persist the state change anyway: the job is gone either way.
+            self._install_state.save()
+
+    def _get_switch_workspace_path(self) -> WorkspacePath:
+        installation_root = self._installation.install_folder()
+        return WorkspacePath(self._ws, installation_root) / "switch"
+
+    def _deploy_resources_to_workspace(self) -> None:
+        """Replicate the Switch package sources to the workspace."""
+        # TODO: This is temporary; instead, the jobs should run the code directly from the deployed wheel/package.
+        resource_root = self._get_switch_workspace_path()
+        # Replace existing resources to avoid stale files and potential confusion.
+        if resource_root.exists():
+            resource_root.rmdir(recursive=True)
+        resource_root.mkdir(parents=True)
+        already_created = {resource_root}
+        logger.info(f"Copying resources to {resource_root} in workspace...")
+        for resource_path, resource in self._enumerate_package_files(switch):
+            # Resource path has a leading 'switch' that we want to strip off.
+            nested_path = resource_path.relative_to(PurePosixPath("switch"))
+            upload_path = resource_root / nested_path
+            if (parent := upload_path.parent) not in already_created:
+                logger.debug(f"Creating workspace directory: {parent}")
+                parent.mkdir()
+                already_created.add(parent)
+            logger.debug(f"Uploading: {resource_path} -> {upload_path}")
+            upload_path.write_bytes(resource.read_bytes())
+        logger.info(f"Completed copying resources to {resource_root} in workspace.")
+
+    @staticmethod
+    def _enumerate_package_files(package) -> Generator[tuple[PurePosixPath, Traversable]]:
+        # Locate the root of the package, and then enumerate all its files recursively.
+        root = importlib.resources.files(package)
+
+        def _enumerate_resources(
+            resource: Traversable, parent: PurePosixPath = PurePosixPath(".")
+        ) -> Generator[tuple[PurePosixPath, Traversable]]:
+            if resource.name.startswith("."):
+                # Skip hidden files and directories
+                return
+            if resource.is_dir():
+                next_parent = parent / resource.name
+                for child in resource.iterdir():
+                    yield from _enumerate_resources(child, next_parent)
+            elif resource.is_file():
+                # Skip compiled Python files (hidden files were already skipped above)
+                if not (name := resource.name).endswith((".pyc", ".pyo")):
+                    yield parent / name, resource
+
+        yield from _enumerate_resources(root)
+
+    def _setup_job(self) -> None:
+        """Create or update the Switch job."""
+        existing_job_id = self._get_existing_job_id()
+        logger.info("Setting up Switch job in workspace...")
+        try:
+            job_id = self._create_or_update_switch_job(existing_job_id)
+            self._install_state.jobs[self._INSTALL_STATE_KEY] = job_id
+            self._install_state.save()
+            job_url = f"{self._ws.config.host}/jobs/{job_id}"
+            logger.info(f"Switch job created/updated: {job_url}")
+        except (RuntimeError, ValueError, InvalidParameterValue) as e:
+            logger.error(f"Failed to create/update Switch job: {e}")
+
+    def _get_existing_job_id(self) -> str | None:
+        """Check if a Switch job already exists in the workspace."""
+        if self._INSTALL_STATE_KEY not in self._install_state.jobs:
+            return None
+        try:
+            job_id = self._install_state.jobs[self._INSTALL_STATE_KEY]
+            self._ws.jobs.get(int(job_id))
+            return job_id
+        except (InvalidParameterValue, NotFound, ValueError):
+            return None
+
+    def _create_or_update_switch_job(self, job_id: str | None) -> str:
+        """Create or update the Switch job, returning the job ID."""
+        job_settings = self._get_switch_job_settings()
+
+        # Try to update existing job
+        if job_id:
+            try:
+                logger.info(f"Updating Switch job: {job_id}")
+                self._ws.jobs.reset(int(job_id), JobSettings(**job_settings))
+                return job_id
+            except (ValueError, InvalidParameterValue):
+                logger.warning("Previous Switch job not found, creating new one")
+
+        # Create new job
+        logger.info("Creating new Switch job")
+        new_job = self._ws.jobs.create(**job_settings)
+        assert new_job.job_id is not None
+        return str(new_job.job_id)
+
+    def _get_switch_job_settings(self) -> dict[str, Any]:
+        """Build job settings for the Switch transpiler."""
+        job_name = "Lakebridge_Switch"
+        notebook_path = self._get_switch_workspace_path() / "notebooks" / "00_main"
+
+        task = Task(
+            task_key="run_transpilation",
+
notebook_task=NotebookTask(notebook_path=str(notebook_path), source=Source.WORKSPACE), + disable_auto_optimization=True, # To disable retries on failure + ) + + return { + "name": job_name, + "tags": {"created_by": self._ws.current_user.me().user_name, "switch_version": f"v{switch_version}"}, + "tasks": [task], + "parameters": self._get_switch_job_parameters(), + "max_concurrent_runs": 100, # Allow simultaneous transpilations + } + + def _get_switch_job_parameters(self) -> Sequence[JobParameterDefinition]: + # Add required runtime parameters, static for now. + parameters = { + "source_tech": "", + "input_dir": "", + "output_dir": "", + } + return [JobParameterDefinition(name=key, default=value) for key, value in parameters.items()] diff --git a/src/databricks/labs/lakebridge/install.py b/src/databricks/labs/lakebridge/install.py index ccd60772a..2ed10759e 100644 --- a/src/databricks/labs/lakebridge/install.py +++ b/src/databricks/labs/lakebridge/install.py @@ -2,14 +2,15 @@ import logging import os import webbrowser -from collections.abc import Set, Callable, Sequence +from collections.abc import Callable, Sequence, Set from pathlib import Path -from typing import Any, cast +from typing import Any from databricks.labs.blueprint.installation import Installation, JsonValue, SerdeError from databricks.labs.blueprint.installer import InstallState from databricks.labs.blueprint.tui import Prompts from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.switch.lsp import get_switch_dialects from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound, PermissionDenied @@ -38,7 +39,8 @@ class WorkspaceInstaller: - def __init__( + # TODO: Temporary suppression, is_interactive is pending removal. + def __init__( # pylint: disable=too-many-arguments self, ws: WorkspaceClient, prompts: Prompts, @@ -50,6 +52,7 @@ def __init__( environ: dict[str, str] | None = None, *, is_interactive: bool = True, + include_llm: bool = False, transpiler_repository: TranspilerRepository = TranspilerRepository.user_home(), transpiler_installers: Sequence[Callable[[TranspilerRepository], TranspilerInstaller]] = ( BladebridgeInstaller, @@ -65,6 +68,7 @@ def __init__( self._ws_installation = workspace_installation # TODO: Refactor the 'prompts' property in preference to using this flag, which should be redundant. 
self._is_interactive = is_interactive + self._include_llm = include_llm self._transpiler_repository = transpiler_repository self._transpiler_installer_factories = transpiler_installers @@ -133,7 +137,9 @@ def configure(self, module: str) -> LakebridgeConfiguration: match module: case "transpile": logger.info("Configuring lakebridge `transpile`.") - return LakebridgeConfiguration(self._configure_transpile(), None) + return LakebridgeConfiguration( + self._configure_transpile(), reconcile=None, include_switch=self._include_llm + ) case "reconcile": logger.info("Configuring lakebridge `reconcile`.") return LakebridgeConfiguration(None, self._configure_reconcile()) @@ -142,6 +148,7 @@ def configure(self, module: str) -> LakebridgeConfiguration: return LakebridgeConfiguration( self._configure_transpile(), self._configure_reconcile(), + include_switch=self._include_llm, ) case _: raise ValueError(f"Invalid input: {module}") @@ -181,7 +188,7 @@ def _configure_new_transpile_installation(self) -> TranspileConfig: schema_name = "transpiler" if not default_config.skip_validation: catalog_name = self._configure_catalog() - schema_name = self._configure_schema(catalog_name, "transpile") + schema_name = self._configure_schema(catalog_name, schema_name) self._has_necessary_access(catalog_name, schema_name) warehouse_id = self._resource_configurator.prompt_for_warehouse_setup(TRANSPILER_WAREHOUSE_PREFIX) runtime_config = {"warehouse_id": warehouse_id} @@ -196,57 +203,87 @@ def _configure_new_transpile_installation(self) -> TranspileConfig: return config def _all_installed_dialects(self) -> list[str]: - return sorted(self._transpiler_repository.all_dialects()) + if self._include_llm: + self._switch_dialects = get_switch_dialects() + return sorted(self._transpiler_repository.all_dialects() | set(self._switch_dialects)) def _transpilers_with_dialect(self, dialect: str) -> list[str]: return sorted(self._transpiler_repository.transpilers_with_dialect(dialect)) def _transpiler_config_path(self, transpiler: str) -> Path: + if transpiler == self._llm_transpiler: + folder = self._installation.install_folder() + return Path(f"{folder}/{self._llm_transpiler}/lsp/config.yml") return self._transpiler_repository.transpiler_config_path(transpiler) + # Sentinel value for special "set it later" option in prompts. + _install_later = "Set it later" + _llm_transpiler = "Switch" + _switch_dialects: list = [] + def _prompt_for_new_transpile_installation(self) -> TranspileConfig: - install_later = "Set it later" # TODO tidy this up, logger might not display the below in console... 
logger.info("Please answer a few questions to configure lakebridge `transpile`") - all_dialects = [install_later, *self._all_installed_dialects()] + transpiler_name: str | None = None + + all_dialects = [self._install_later, *self._all_installed_dialects()] + source_dialect: str | None = self._prompts.choice("Select the source dialect:", all_dialects, sort=False) - if source_dialect == install_later: + if source_dialect == self._install_later: source_dialect = None - transpiler_name: str | None = None transpiler_config_path: Path | None = None + transpiler_options: dict[str, JsonValue] | None = None if source_dialect: transpilers = self._transpilers_with_dialect(source_dialect) - if len(transpilers) > 1: - transpilers = [install_later] + transpilers - transpiler_name = self._prompts.choice("Select the transpiler:", transpilers, sort=False) - if transpiler_name == install_later: - transpiler_name = None - else: - transpiler_name = next(t for t in transpilers) - logger.info(f"Lakebridge will use the {transpiler_name} transpiler") - if transpiler_name: - transpiler_config_path = self._transpiler_config_path(transpiler_name) - transpiler_options: dict[str, JsonValue] | None = None - if transpiler_config_path: - transpiler_options = self._prompt_for_transpiler_options( - cast(str, transpiler_name), cast(str, source_dialect) - ) + if self._include_llm and source_dialect in self._switch_dialects: + transpilers.append(self._llm_transpiler) + if (found_config := self._get_transpiler_config(transpilers)) is not None: + transpiler_name, transpiler_config_path = found_config + transpiler_options = self._prompt_for_transpiler_options(transpiler_name, source_dialect) input_source: str | None = self._prompts.question( - "Enter input SQL path (directory/file)", default=install_later + "Enter input SQL path (directory/file)", default=self._install_later ) - if input_source == install_later: + if input_source == self._install_later: input_source = None output_folder = self._prompts.question("Enter output directory", default="transpiled") - # When defaults are passed along we need to use absolute paths to avoid issues with relative paths + # When defaults are passed along we need to use absolute paths to avoid issues with relative paths. if output_folder == "transpiled": output_folder = str(Path.cwd() / "transpiled") error_file_path = self._prompts.question("Enter error file path", default="errors.log") if error_file_path == "errors.log": error_file_path = str(Path.cwd() / "errors.log") - run_validation = self._prompts.confirm( - "Would you like to validate the syntax and semantics of the transpiled queries?" - ) + run_validation = False + + if transpiler_name == self._llm_transpiler: + logger.info("Note: Switch transpiler is LLM Transpiler has a different execution process") + logger.info("Starting the additional configuration required for Switch...") + logger.info("Please provide the **Mandatory** following resources to set up Switch:") + catalog = self._resource_configurator.prompt_for_catalog_setup("lakebridge") + schema = self._resource_configurator.prompt_for_schema_setup(catalog, "switch") + volume = self._resource_configurator.prompt_for_volume_setup(catalog, schema, "switch_volume") + + # Prompt for a Foundation Model serving endpoint to use with the LLM converter. 
+ logger.info("Now select the foundational model to use with Switch for LLM Transpile") + foundation_model = self._resource_configurator.prompt_for_foundation_model_choice() + + # Ensure transpiler options exist and persist the choices for later use by the transpiler. + transpiler_options = {} if transpiler_options is None else dict(transpiler_options) + transpiler_options.update( + { + "transpiler_name": transpiler_name, + "foundation_model": foundation_model, + "catalog": catalog, + "schema": schema, + "volume": volume, + } + ) + + else: + # LLM Converter bakes in LLM validation, so no need to ask the user again. + run_validation = self._prompts.confirm( + "Would you like to validate the syntax and semantics of the transpiled queries?" + ) return TranspileConfig( transpiler_config_path=str(transpiler_config_path) if transpiler_config_path is not None else None, @@ -258,6 +295,19 @@ def _prompt_for_new_transpile_installation(self) -> TranspileConfig: error_file_path=error_file_path, ) + def _get_transpiler_config(self, transpiler_names: list[str]) -> tuple[str, Path] | None: + match transpiler_names: + case [only_name]: + transpiler_name = only_name + logger.info(f"Lakebridge will use the {transpiler_name} transpiler") + case _: + choices = [self._install_later, *transpiler_names] + transpiler_name = self._prompts.choice("Select the transpiler:", choices, sort=False) + if transpiler_name == self._install_later: + return None + transpiler_config_path = self._transpiler_config_path(transpiler_name) + return transpiler_name, transpiler_config_path + def _prompt_for_transpiler_options(self, transpiler_name: str, source_dialect: str) -> dict[str, Any] | None: config_options = self._transpiler_repository.transpiler_config_options(transpiler_name, source_dialect) if len(config_options) == 0: @@ -391,6 +441,7 @@ def installer( transpiler_repository: TranspilerRepository, *, is_interactive: bool, + include_llm: bool = False, ) -> WorkspaceInstaller: app_context = ApplicationContext(_verify_workspace_client(ws)) return WorkspaceInstaller( @@ -403,6 +454,7 @@ def installer( app_context.workspace_installation, transpiler_repository=transpiler_repository, is_interactive=is_interactive, + include_llm=include_llm, ) diff --git a/tests/unit/deployment/test_configurator.py b/tests/unit/deployment/test_configurator.py index a15073115..1c078360b 100644 --- a/tests/unit/deployment/test_configurator.py +++ b/tests/unit/deployment/test_configurator.py @@ -9,6 +9,12 @@ SchemaInfo, VolumeInfo, ) +from databricks.sdk.service.serving import ( + EndpointCoreConfigSummary, + FoundationModel, + ServedEntitySpec, + ServingEndpoint, +) from databricks.sdk.service.sql import EndpointInfo, EndpointInfoWarehouseType, GetWarehouseResponse, State from databricks.labs.lakebridge.deployment.configurator import ResourceConfigurator @@ -314,3 +320,63 @@ def test_prompt_for_warehouse_setup_new(ws): catalog_operations = create_autospec(CatalogOperations) configurator = ResourceConfigurator(ws, prompts, catalog_operations) assert configurator.prompt_for_warehouse_setup("Test") == "new_w_id" + + +def test_prompt_for_foundation_model_default_choice(ws): + ws.serving_endpoints.list.return_value = [ + ServingEndpoint( + name="databricks-claude-sonnet-4-5", + config=EndpointCoreConfigSummary( + served_entities=[ + ServedEntitySpec( + foundation_model=FoundationModel(name="claude-sonnet-4.5"), + ) + ] + ), + ), + ServingEndpoint( + name="databricks-gpt-4", + config=EndpointCoreConfigSummary( + served_entities=[ + ServedEntitySpec( + 
foundation_model=FoundationModel(name="gpt-4"), + ) + ] + ), + ), + ] + prompts = MockPrompts({r"Select a Foundation Model serving endpoint:": "0"}) + catalog_operations = create_autospec(CatalogOperations) + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + result = configurator.prompt_for_foundation_model_choice() + assert result == "databricks-claude-sonnet-4-5" + + +def test_prompt_for_foundation_model_non_default_choice(ws): + ws.serving_endpoints.list.return_value = [ + ServingEndpoint( + name="databricks-claude-sonnet-4-5", + config=EndpointCoreConfigSummary( + served_entities=[ + ServedEntitySpec( + foundation_model=FoundationModel(name="claude-sonnet-4.5"), + ) + ] + ), + ), + ServingEndpoint( + name="databricks-gpt-4", + config=EndpointCoreConfigSummary( + served_entities=[ + ServedEntitySpec( + foundation_model=FoundationModel(name="gpt-4"), + ) + ] + ), + ), + ] + prompts = MockPrompts({r"Select a Foundation Model serving endpoint:": "1"}) + catalog_operations = create_autospec(CatalogOperations) + configurator = ResourceConfigurator(ws, prompts, catalog_operations) + result = configurator.prompt_for_foundation_model_choice() + assert result == "databricks-gpt-4" diff --git a/tests/unit/deployment/test_installation.py b/tests/unit/deployment/test_installation.py index b7579bf7c..039e4a412 100644 --- a/tests/unit/deployment/test_installation.py +++ b/tests/unit/deployment/test_installation.py @@ -19,6 +19,7 @@ ) from databricks.labs.lakebridge.deployment.installation import WorkspaceInstallation from databricks.labs.lakebridge.deployment.recon import ReconDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment @pytest.fixture @@ -37,6 +38,7 @@ def test_install_all(ws): } ) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) product_info = create_autospec(ProductInfo) upgrades = create_autospec(Upgrades) @@ -66,13 +68,16 @@ def test_install_all(ws): ), ) config = LakebridgeConfiguration(transpile=transpile_config, reconcile=reconcile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) def test_no_recon_component_installation(ws): prompts = MockPrompts({}) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) product_info = create_autospec(ProductInfo) upgrades = create_autospec(Upgrades) @@ -86,14 +91,17 @@ def test_no_recon_component_installation(ws): catalog_name="remorph7", schema_name="transpiler7", ) - config = LakebridgeConfiguration(transpile=transpile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + config = LakebridgeConfiguration(transpile=transpile_config, reconcile=None) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) recon_deployment.install.assert_not_called() def test_recon_component_installation(ws): recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) installation = create_autospec(Installation) prompts = MockPrompts({}) product_info = 
create_autospec(ProductInfo) @@ -114,8 +122,10 @@ def test_recon_component_installation(ws): volume="reconcile_volume8", ), ) - config = LakebridgeConfiguration(reconcile=reconcile_config) - installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, product_info, upgrades) + config = LakebridgeConfiguration(reconcile=reconcile_config, transpile=None) + installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) installation.install(config) recon_deployment.install.assert_called() @@ -128,11 +138,14 @@ def test_negative_uninstall_confirmation(ws): ) installation = create_autospec(Installation) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) - config = LakebridgeConfiguration() + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) + config = LakebridgeConfiguration(transpile=None, reconcile=None) ws_installation.uninstall(config) installation.remove.assert_not_called() @@ -147,11 +160,14 @@ def test_missing_installation(ws): installation.files.side_effect = NotFound("Installation not found") installation.install_folder.return_value = "~/mock" recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) - config = LakebridgeConfiguration() + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) + config = LakebridgeConfiguration(transpile=None, reconcile=None) ws_installation.uninstall(config) installation.remove.assert_not_called() @@ -193,10 +209,13 @@ def test_uninstall_configs_exist(ws): config = LakebridgeConfiguration(transpile=transpile_config, reconcile=reconcile_config) installation = MockInstallation({}) recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) ws_installation.uninstall(config) recon_deployment.uninstall.assert_called() installation.assert_removed() @@ -210,11 +229,36 @@ def test_uninstall_configs_missing(ws): ) installation = MockInstallation() recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) wheels = create_autospec(WheelsV2) upgrades = create_autospec(Upgrades) - ws_installation = WorkspaceInstallation(ws, prompts, installation, recon_deployment, wheels, upgrades) - config = LakebridgeConfiguration() + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, wheels, upgrades + ) + config = LakebridgeConfiguration(transpile=None, reconcile=None) ws_installation.uninstall(config) recon_deployment.uninstall.assert_not_called() installation.assert_removed() + + +class TestSwitchInstallation: + """Tests for Switch 
transpiler installation.""" + + def test_switch_install_uses_configured_resources(self, ws): + prompts = MockPrompts({}) + recon_deployment = create_autospec(ReconDeployment) + switch_deployment = create_autospec(SwitchDeployment) + installation = create_autospec(Installation) + product_info = create_autospec(ProductInfo) + upgrades = create_autospec(Upgrades) + + config = LakebridgeConfiguration(transpile=TranspileConfig(), reconcile=None, include_switch=True) + + ws_installation = WorkspaceInstallation( + ws, prompts, installation, recon_deployment, switch_deployment, product_info, upgrades + ) + + ws_installation.install(config) + + switch_deployment.install.assert_called_once() diff --git a/tests/unit/deployment/test_switch.py b/tests/unit/deployment/test_switch.py new file mode 100644 index 000000000..b61bd81ed --- /dev/null +++ b/tests/unit/deployment/test_switch.py @@ -0,0 +1,324 @@ +from unittest.mock import create_autospec +from typing import Any, cast +import pytest + +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.lakebridge.config import LakebridgeConfiguration +from databricks.labs.lakebridge.deployment.job import JobDeployment +from databricks.labs.lakebridge.deployment.switch import SwitchDeployment +from databricks.sdk import WorkspaceClient, JobsExt +from databricks.sdk.errors import NotFound, InvalidParameterValue +from databricks.sdk.service.jobs import CreateResponse +from databricks.sdk.service.iam import User + + +@pytest.fixture() +def mock_workspace_client() -> WorkspaceClient: + ws: Any = create_autospec(WorkspaceClient, instance=True) + ws.current_user.me.return_value = User(user_name="test_user") + ws.config.host = "https://test.databricks.com" + ws.jobs = cast(Any, create_autospec(JobsExt, instance=True)) + return ws + + +@pytest.fixture() +def installation() -> MockInstallation: + return MockInstallation(is_global=False) + + +@pytest.fixture() +def install_state(installation: MockInstallation) -> InstallState: + return InstallState.from_installation(installation) + + +@pytest.fixture() +def product_info() -> ProductInfo: + return ProductInfo.for_testing(LakebridgeConfiguration) + + +@pytest.fixture() +def job_deployer() -> JobDeployment: + return create_autospec(JobDeployment, instance=True) + + +@pytest.fixture() +def switch_deployment( + mock_workspace_client: Any, + installation: MockInstallation, + install_state: InstallState, + product_info: ProductInfo, + job_deployer: JobDeployment, +) -> SwitchDeployment: + return SwitchDeployment(mock_workspace_client, installation, install_state, product_info, job_deployer) + + +def test_install_creates_job_successfully( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + """Test successful installation creates job and saves state.""" + mock_workspace_client.jobs.create.return_value = CreateResponse(job_id=123) + + switch_deployment.install() + + assert install_state.jobs["Switch"] == "123" + mock_workspace_client.jobs.create.assert_called_once() + + +def test_install_updates_existing_job( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + """Test installation updates existing job if found.""" + install_state.jobs["Switch"] = "456" + mock_workspace_client.jobs.get.return_value = create_autospec(CreateResponse, instance=True) + + 
switch_deployment.install() + + assert install_state.jobs["Switch"] == "456" + mock_workspace_client.jobs.reset.assert_called_once() + mock_workspace_client.jobs.create.assert_not_called() + + +def test_install_creates_new_job_when_existing_not_found( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + """Test installation creates new job when existing job is not found.""" + install_state.jobs["Switch"] = "789" + mock_workspace_client.jobs.get.side_effect = NotFound("Job not found") + mock_workspace_client.jobs.create.return_value = CreateResponse(job_id=999) + + switch_deployment.install() + + assert install_state.jobs["Switch"] == "999" + mock_workspace_client.jobs.create.assert_called_once() + + +def test_install_handles_job_creation_error( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + """Test installation handles job creation errors gracefully.""" + mock_workspace_client.jobs.create.side_effect = RuntimeError("Job creation failed") + + switch_deployment.install() + + # State should not be updated on error + assert "Switch" not in install_state.jobs + + +def test_install_handles_invalid_parameter_error( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + """Test installation handles invalid parameter errors gracefully.""" + mock_workspace_client.jobs.create.side_effect = InvalidParameterValue("Invalid parameter") + + switch_deployment.install() + + assert "Switch" not in install_state.jobs + + +def test_install_fallback_on_update_failure( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "555" + mock_workspace_client.jobs.get.return_value = create_autospec(CreateResponse, instance=True) + mock_workspace_client.jobs.reset.side_effect = InvalidParameterValue("Update failed") + new_job = CreateResponse(job_id=666) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + assert install_state.jobs["Switch"] == "666" + mock_workspace_client.jobs.reset.assert_called_once() + mock_workspace_client.jobs.create.assert_called_once() + + +def test_install_with_invalid_existing_job_id( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "not_a_number" + mock_workspace_client.jobs.get.side_effect = ValueError("Invalid job ID") + new_job = CreateResponse(job_id=777) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + assert install_state.jobs["Switch"] == "777" + mock_workspace_client.jobs.create.assert_called_once() + + +def test_install_preserves_other_jobs_in_state( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["OtherJob"] = "999" + new_job = CreateResponse(job_id=123) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + assert install_state.jobs["Switch"] == "123" + assert install_state.jobs["OtherJob"] == "999" + + +def test_install_configures_job_with_correct_parameters( + switch_deployment: SwitchDeployment, mock_workspace_client: Any +) -> None: + """Test installation configures job with correct parameters.""" + new_job = CreateResponse(job_id=123) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + # 
Verify job creation was called with settings + mock_workspace_client.jobs.create.assert_called_once() + call_kwargs = mock_workspace_client.jobs.create.call_args.kwargs + + # Verify job name + assert call_kwargs["name"] == "Lakebridge_Switch" + + # Verify tags + assert "created_by" in call_kwargs["tags"] + assert call_kwargs["tags"]["created_by"] == "test_user" + assert "switch_version" in call_kwargs["tags"] + + # Verify tasks + assert len(call_kwargs["tasks"]) == 1 + assert call_kwargs["tasks"][0].task_key == "run_transpilation" + assert call_kwargs["tasks"][0].disable_auto_optimization is True + + # Verify parameters + param_names = {param.name for param in call_kwargs["parameters"]} + assert param_names == {"source_tech", "input_dir", "output_dir"} + + # Verify max concurrent runs + assert call_kwargs["max_concurrent_runs"] == 100 + + +def test_install_configures_job_with_correct_notebook_path( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, installation: MockInstallation +) -> None: + """Test installation configures job with correct notebook path.""" + new_job = CreateResponse(job_id=123) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + call_kwargs = mock_workspace_client.jobs.create.call_args.kwargs + notebook_path = call_kwargs["tasks"][0].notebook_task.notebook_path + + # Verify notebook path includes switch directory and notebook name + assert "switch" in notebook_path + assert "notebooks" in notebook_path + assert "00_main" in notebook_path + + +def test_uninstall_removes_job_successfully( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "123" + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + mock_workspace_client.jobs.delete.assert_called_once_with(123) + + +def test_uninstall_handles_job_not_found( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "456" + mock_workspace_client.jobs.delete.side_effect = NotFound("Job not found") + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + + +def test_uninstall_handles_invalid_parameter( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "789" + mock_workspace_client.jobs.delete.side_effect = InvalidParameterValue("Invalid job ID") + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + + +def test_uninstall_no_job_in_state(switch_deployment: SwitchDeployment, mock_workspace_client: Any) -> None: + switch_deployment.uninstall() + + mock_workspace_client.jobs.delete.assert_not_called() + + +def test_uninstall_with_invalid_job_id_format( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "not_a_number" + + # Should raise ValueError when trying to convert to int + with pytest.raises(ValueError): + switch_deployment.uninstall() + + +def test_uninstall_preserves_other_jobs_in_state( + switch_deployment: SwitchDeployment, mock_workspace_client: Any, install_state: InstallState +) -> None: + install_state.jobs["Switch"] = "123" + install_state.jobs["OtherJob"] = "999" + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + assert install_state.jobs["OtherJob"] == "999" + + +# Parameterized tests + + 
+@pytest.mark.parametrize( + "exception", + [ + NotFound("Job not found"), + InvalidParameterValue("Invalid parameter"), + ], +) +def test_uninstall_handles_exceptions( + switch_deployment: SwitchDeployment, + mock_workspace_client: Any, + install_state: InstallState, + exception, +) -> None: + install_state.jobs["Switch"] = "123" + mock_workspace_client.jobs.delete.side_effect = exception + + switch_deployment.uninstall() + + assert "Switch" not in install_state.jobs + + +@pytest.mark.parametrize( + "exception,expected_job_id", + [ + (InvalidParameterValue("Update failed"), 888), + (ValueError("Invalid job ID"), 777), + ], +) +def test_install_creates_new_job_on_update_failure( + switch_deployment: SwitchDeployment, + mock_workspace_client: Any, + install_state: InstallState, + exception, + expected_job_id, +) -> None: + install_state.jobs["Switch"] = "555" + mock_workspace_client.jobs.get.return_value = create_autospec(CreateResponse, instance=True) + mock_workspace_client.jobs.reset.side_effect = exception + new_job = CreateResponse(job_id=expected_job_id) + mock_workspace_client.jobs.create.return_value = new_job + + switch_deployment.install() + + assert install_state.jobs["Switch"] == str(expected_job_id) + mock_workspace_client.jobs.reset.assert_called_once() + mock_workspace_client.jobs.create.assert_called_once() diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 5d197a226..1b0e05ea1 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -9,6 +9,7 @@ from databricks.sdk.service import iam from databricks.labs.blueprint.tui import MockPrompts from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 + from databricks.labs.lakebridge.config import ( DatabaseConfig, LSPConfigOptionV1, @@ -42,10 +43,10 @@ def ws() -> WorkspaceClient: SET_IT_LATER = ["Set it later"] -ALL_INSTALLED_DIALECTS_NO_LATER = sorted(["tsql", "snowflake"]) +ALL_INSTALLED_DIALECTS_NO_LATER = sorted(["tsql", "snowflake", "mssql"]) ALL_INSTALLED_DIALECTS = SET_IT_LATER + ALL_INSTALLED_DIALECTS_NO_LATER -TRANSPILERS_FOR_SNOWFLAKE_NO_LATER = sorted(["Remorph Community Transpiler", "Morpheus"]) -TRANSPILERS_FOR_SNOWFLAKE = SET_IT_LATER + TRANSPILERS_FOR_SNOWFLAKE_NO_LATER +TRANSPILERS_LIST_NO_LATER = sorted(["Bladebridge", "Morpheus", "Remorph Community Transpiler", "Switch"]) +TRANSPILERS_LIST = SET_IT_LATER + TRANSPILERS_LIST_NO_LATER PATH_TO_TRANSPILER_CONFIG = "/some/path/to/config.yml" @@ -67,7 +68,7 @@ def _all_installed_dialects(self): return ALL_INSTALLED_DIALECTS_NO_LATER def _transpilers_with_dialect(self, dialect): - return TRANSPILERS_FOR_SNOWFLAKE_NO_LATER + return TRANSPILERS_LIST_NO_LATER def _transpiler_config_path(self, transpiler): return PATH_TO_TRANSPILER_CONFIG @@ -106,7 +107,7 @@ def test_workspace_installer_run_install_not_called_in_test( workspace_installation=ws_installation, ) - provided_config = LakebridgeConfiguration() + provided_config = LakebridgeConfiguration(transpile=None, reconcile=None) workspace_installer = ws_installer( ctx.workspace_client, ctx.prompts, @@ -133,7 +134,7 @@ def test_workspace_installer_run_install_called_with_provided_config( resource_configurator=create_autospec(ResourceConfigurator), workspace_installation=ws_installation, ) - provided_config = LakebridgeConfiguration() + provided_config = LakebridgeConfiguration(transpile=None, reconcile=None) workspace_installer = ws_installer( ctx.workspace_client, ctx.prompts, @@ -178,7 +179,7 @@ def test_workspace_installer_run_install_called_with_generated_config( { 
r"Do you want to override the existing installation?": "no", r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -229,7 +230,7 @@ def test_configure_transpile_no_existing_installation( { r"Do you want to override the existing installation?": "no", r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -267,7 +268,7 @@ def test_configure_transpile_no_existing_installation( catalog_name="remorph", schema_name="transpiler", ) - expected_config = LakebridgeConfiguration(transpile=expected_morph_config) + expected_config = LakebridgeConfiguration(transpile=expected_morph_config, reconcile=None) assert config == expected_config installation.assert_file_written( "config.yml", @@ -343,7 +344,7 @@ def test_configure_transpile_installation_config_error_continue_install( { r"Do you want to override the existing installation?": "yes", r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -398,7 +399,7 @@ def test_configure_transpile_installation_config_error_continue_install( catalog_name="remorph", schema_name="transpiler", ) - expected_config = LakebridgeConfiguration(transpile=expected_morph_config) + expected_config = LakebridgeConfiguration(transpile=expected_morph_config, reconcile=None) assert config == expected_config installation.assert_file_written( "config.yml", @@ -421,7 +422,7 @@ def test_configure_transpile_installation_with_no_validation(ws, ws_installer): prompts = MockPrompts( { r"Select the source dialect": ALL_INSTALLED_DIALECTS.index("snowflake"), - r"Select the transpiler": TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus"), + r"Select the transpiler": TRANSPILERS_LIST.index("Morpheus"), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -461,7 +462,7 @@ def test_configure_transpile_installation_with_no_validation(ws, ws_installer): catalog_name="remorph", schema_name="transpiler", ) - expected_config = LakebridgeConfiguration(transpile=expected_morph_config) + expected_config = LakebridgeConfiguration(transpile=expected_morph_config, reconcile=None) assert config == expected_config installation.assert_file_written( "config.yml", @@ -486,7 +487,7 @@ def test_configure_transpile_installation_with_validation_and_warehouse_id_from_ prompts = MockPrompts( { r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL 
path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -532,7 +533,8 @@ def test_configure_transpile_installation_with_validation_and_warehouse_id_from_ catalog_name="remorph_test", schema_name="transpiler_test", sdk_config={"warehouse_id": "w_id"}, - ) + ), + reconcile=None, ) assert config == expected_config installation.assert_file_written( @@ -669,7 +671,8 @@ def test_configure_reconcile_installation_config_error_continue_install(ws: Work schema="reconcile", volume="reconcile_volume", ), - ) + ), + transpile=None, ) assert config == expected_config installation.assert_file_written( @@ -748,7 +751,8 @@ def test_configure_reconcile_no_existing_installation(ws: WorkspaceClient) -> No schema="reconcile", volume="reconcile_volume", ), - ) + ), + transpile=None, ) assert config == expected_config installation.assert_file_written( @@ -781,7 +785,7 @@ def test_configure_all_override_installation( { r"Do you want to override the existing installation?": "yes", r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -957,7 +961,7 @@ def test_runs_upgrades_on_more_recent_version( { r"Do you want to override the existing installation?": "yes", r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Morpheus")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Morpheus")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file.*": "/tmp/queries/errors.log", @@ -1001,7 +1005,8 @@ def test_runs_upgrades_on_more_recent_version( catalog_name="remorph", schema_name="transpiler", skip_validation=True, - ) + ), + reconcile=None, ) ) @@ -1014,7 +1019,7 @@ def test_runs_and_stores_confirm_config_option( prompts = MockPrompts( { r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Remorph Community Transpiler")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Remorph Community Transpiler")), r"Do you want to use the experimental Databricks generator ?": "yes", r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", @@ -1070,7 +1075,8 @@ def transpilers_path(self) -> Path: catalog_name="remorph_test", schema_name="transpiler_test", sdk_config={"warehouse_id": "w_id"}, - ) + ), + reconcile=None, ) assert config == expected_config installation.assert_file_written( @@ -1107,7 +1113,7 @@ def test_runs_and_stores_force_config_option( prompts = MockPrompts( { r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Remorph Community Transpiler")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Remorph Community Transpiler")), r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", r"Enter error file path.*": "/tmp/queries/errors.log", @@ -1158,7 +1164,8 @@ def test_runs_and_stores_force_config_option( catalog_name="remorph_test", schema_name="transpiler_test", 
sdk_config={"warehouse_id": "w_id"}, - ) + ), + reconcile=None, ) assert config == expected_config installation.assert_file_written( @@ -1186,7 +1193,7 @@ def test_runs_and_stores_question_config_option( prompts = MockPrompts( { r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Remorph Community Transpiler")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Remorph Community Transpiler")), r"Max number of heaps:": "1254", r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", @@ -1239,7 +1246,8 @@ def test_runs_and_stores_question_config_option( catalog_name="remorph_test", schema_name="transpiler_test", sdk_config={"warehouse_id": "w_id"}, - ) + ), + reconcile=None, ) assert config == expected_config installation.assert_file_written( @@ -1267,7 +1275,7 @@ def test_runs_and_stores_choice_config_option( prompts = MockPrompts( { r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("snowflake")), - r"Select the transpiler": str(TRANSPILERS_FOR_SNOWFLAKE.index("Remorph Community Transpiler")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Remorph Community Transpiler")), r"Select currency:": "2", r"Enter input SQL path.*": "/tmp/queries/snow", r"Enter output directory.*": "/tmp/queries/databricks", @@ -1326,7 +1334,8 @@ def test_runs_and_stores_choice_config_option( catalog_name="remorph_test", schema_name="transpiler_test", sdk_config={"warehouse_id": "w_id"}, - ) + ), + reconcile=None, ) assert config == expected_config installation.assert_file_written( @@ -1467,7 +1476,7 @@ def test_installer_upgrade_configure_if_changed( prompts=MockPrompts( { r"Do you want to override the existing installation?": "yes", - r"Select the source dialect": "2", + r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("tsql")), r"Select the transpiler": "1", r"Enter .*": "/tmp/updated", r"Would you like to validate.*": "no", @@ -1600,3 +1609,111 @@ def test_no_configure_if_noninteractive( assert config.transpile is None expected_log_message = "Installation is not interactive, skipping configuration of transpilers." 
assert any(expected_log_message in log.message for log in caplog.records if log.levelno == logging.WARNING) + + +@pytest.mark.xfail(raises=ValueError, reason="Reconcile module can't yet be configured in non-interactive mode") +@pytest.mark.parametrize( + ("include_llm_transpiler", "should_include_switch"), + ( + (False, False), # Default: exclude Switch + (True, True), # Flag enabled: include Switch + (None, False), # Not specified: default behavior (exclude Switch) + ), +) +def test_transpiler_installers_llm_flag( + ws_installer: Callable[..., WorkspaceInstaller], + ws: WorkspaceClient, + include_llm_transpiler: bool | None, + should_include_switch: bool, +) -> None: + """Test switch configuration flag based on the include_llm parameter.""" + ctx = ApplicationContext(ws).replace( + product_info=ProductInfo.for_testing(LakebridgeConfiguration), + prompts=MockPrompts({}), + installation=MockInstallation({}), + ) + kw_args = {"include_llm": include_llm_transpiler} if include_llm_transpiler is not None else {} + installer = ws_installer( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + is_interactive=False, + **kw_args, + ) + assert installer.configure("transpile").include_switch == should_include_switch + assert installer.configure("all").include_switch == should_include_switch + + +def test_workspace_installer_run_install_called_with_generated_config_switch( + ws_installer: Callable[..., WorkspaceInstaller], + ws: WorkspaceClient, +) -> None: + prompts = MockPrompts( + { + r"Do you want to override the existing installation?": "no", + r"Select the source dialect": str(ALL_INSTALLED_DIALECTS.index("mssql")), + r"Select the transpiler": str(TRANSPILERS_LIST.index("Switch")), + r"Enter input SQL path.*": "/tmp/queries/mssql", + r"Enter output directory.*": "/tmp/queries/databricks", + r"Enter error file path.*": "/tmp/queries/errors.log", + r"Would you like to validate.*": "no", + r"Open .* in the browser?": "no", + r"Enter catalog name.*": "lakebridge", + r"Enter schema name.*": "switch", + r"Enter volume name.*": "switch_volume", + r"Select a Foundation Model serving endpoint.*": "0", + } + ) + + installation = MockInstallation() + resource_configurator = create_autospec(ResourceConfigurator) + resource_configurator.prompt_for_catalog_setup.return_value = "lakebridge" + resource_configurator.prompt_for_schema_setup.return_value = "switch" + resource_configurator.prompt_for_volume_setup.return_value = "switch_volume" + resource_configurator.prompt_for_foundation_model_choice.return_value = "databricks-claude-sonnet-4-5" + ctx = ApplicationContext(ws) + + ctx.replace( + prompts=prompts, + installation=installation, + resource_configurator=resource_configurator, + workspace_installation=create_autospec(WorkspaceInstallation), + ) + + workspace_installer = ws_installer( + ctx.workspace_client, + ctx.prompts, + ctx.installation, + ctx.install_state, + ctx.product_info, + ctx.resource_configurator, + ctx.workspace_installation, + is_interactive=True, + include_llm=True, + ) + workspace_installer.run("transpile") + installation.assert_file_written( + "config.yml", + { + "catalog_name": "remorph", + "transpiler_config_path": PATH_TO_TRANSPILER_CONFIG, + "source_dialect": "mssql", + "input_source": "/tmp/queries/mssql", + "output_folder": "/tmp/queries/databricks", + "error_file_path": "/tmp/queries/errors.log", + "schema_name": "transpiler", + "skip_validation": True, + 
"transpiler_options": { + "catalog": "lakebridge", + "foundation_model": "databricks-claude-sonnet-4-5", + "schema": "switch", + "transpiler_name": "Switch", + "volume": "switch_volume", + }, + "version": 3, + }, + ) diff --git a/tests/unit/test_uninstall.py b/tests/unit/test_uninstall.py index f2a980e6b..0641aece5 100644 --- a/tests/unit/test_uninstall.py +++ b/tests/unit/test_uninstall.py @@ -24,7 +24,7 @@ def test_uninstaller_run(ws): ctx = ApplicationContext(ws) ctx.replace( workspace_installation=ws_installation, - remorph_config=LakebridgeConfiguration(), + remorph_config=LakebridgeConfiguration(transpile=None, reconcile=None), ) uninstall.run(ctx) ws_installation.uninstall.assert_called_once()