Skip to content

Commit ff97db8

Browse files
authored
Added get_tables_to_migrate functionality in the mapping module (#755)
1 parent f329875 commit ff97db8

File tree

8 files changed

+705
-124
lines changed

8 files changed

+705
-124
lines changed

src/databricks/labs/ucx/cli.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def skip(w: WorkspaceClient, schema: str | None = None, table: str | None = None
6464
return None
6565
warehouse_id = installation.config.warehouse_id
6666
sql_backend = StatementExecutionBackend(w, warehouse_id)
67-
mapping = TableMapping(w)
67+
mapping = TableMapping(w, sql_backend)
6868
if table:
69-
mapping.skip_table(sql_backend, schema, table)
69+
mapping.skip_table(schema, table)
7070
else:
71-
mapping.skip_schema(sql_backend, schema)
71+
mapping.skip_schema(schema)
7272

7373

7474
@ucx.command(is_account=True)
@@ -90,7 +90,10 @@ def manual_workspace_info(w: WorkspaceClient):
9090
@ucx.command
9191
def create_table_mapping(w: WorkspaceClient):
9292
"""create initial table mapping for review"""
93-
table_mapping = TableMapping(w)
93+
installation_manager = InstallationManager(w)
94+
installation = installation_manager.for_user(w.current_user.me())
95+
sql_backend = StatementExecutionBackend(w, installation.config.warehouse_id)
96+
table_mapping = TableMapping(w, sql_backend)
9497
workspace_info = WorkspaceInfo(w)
9598
installation_manager = InstallationManager(w)
9699
installation = installation_manager.for_user(w.current_user.me())
@@ -121,9 +124,8 @@ def ensure_assessment_run(w: WorkspaceClient):
121124
if not installation:
122125
logger.error(CANT_FIND_UCX_MSG)
123126
return None
124-
else:
125-
workspace_installer = WorkspaceInstaller(w)
126-
workspace_installer.validate_and_run("assessment")
127+
workspace_installer = WorkspaceInstaller(w)
128+
workspace_installer.validate_and_run("assessment")
127129

128130

129131
@ucx.command
@@ -155,7 +157,7 @@ def revert_migrated_tables(w: WorkspaceClient, schema: str, table: str, *, delet
155157
warehouse_id = installation.config.warehouse_id
156158
sql_backend = StatementExecutionBackend(w, warehouse_id)
157159
table_crawler = TablesCrawler(sql_backend, installation.config.inventory_database)
158-
tmp = TableMapping(w)
160+
tmp = TableMapping(w, sql_backend)
159161
tm = TablesMigrate(table_crawler, w, sql_backend, tmp)
160162
if tm.print_revert_report(delete_managed=delete_managed) and prompts.confirm(
161163
"Would you like to continue?", max_attempts=2

src/databricks/labs/ucx/hive_metastore/mapping.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import logging
55
import re
66
from dataclasses import dataclass
7+
from functools import partial
78

9+
from databricks.labs.blueprint.parallel import Threads
810
from databricks.sdk import WorkspaceClient
9-
from databricks.sdk.errors import BadRequest, NotFound
11+
from databricks.sdk.errors import BadRequest, NotFound, ResourceConflict
1012
from databricks.sdk.service.workspace import ImportFormat
1113

1214
from databricks.labs.ucx.account import WorkspaceInfo
13-
from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend
15+
from databricks.labs.ucx.framework.crawlers import SqlBackend
1416
from databricks.labs.ucx.hive_metastore import TablesCrawler
1517
from databricks.labs.ucx.hive_metastore.tables import Table
1618

@@ -46,14 +48,21 @@ def as_hms_table_key(self):
4648
return f"hive_metastore.{self.src_schema}.{self.src_table}"
4749

4850

51+
@dataclass
class TableToMigrate:
    """Pairs a crawled Hive metastore table with the mapping rule that drives its migration."""

    src: Table
    rule: Rule
55+
56+
4957
class TableMapping:
5058
UCX_SKIP_PROPERTY = "databricks.labs.ucx.skip"
5159

52-
def __init__(self, ws: WorkspaceClient, folder: str | None = None):
60+
def __init__(self, ws: WorkspaceClient, backend: SqlBackend, folder: str | None = None):
    """Create a table-mapping helper.

    Args:
        ws: workspace client used for workspace-file and Unity Catalog lookups.
        backend: SQL backend used to run ALTER/DESCRIBE/SHOW statements.
        folder: workspace folder that holds mapping.csv; defaults to the
            calling user's .ucx installation folder.
    """
    self._ws = ws
    self._backend = backend
    # Fall back to the current user's UCX folder only when no folder was given.
    self._folder = folder or f"/Users/{ws.current_user.me().user_name}/.ucx"
    self._field_names = [field.name for field in dataclasses.fields(Rule)]
5867

5968
def current_tables(self, tables: TablesCrawler, workspace_name: str, catalog_name: str):
@@ -75,11 +84,6 @@ def save(self, tables: TablesCrawler, workspace_info: WorkspaceInfo) -> str:
7584
buffer.seek(0)
7685
return self._overwrite_mapping(buffer)
7786

78-
def _overwrite_mapping(self, buffer) -> str:
79-
path = f"{self._folder}/mapping.csv"
80-
self._ws.workspace.upload(path, buffer, overwrite=True, format=ImportFormat.AUTO)
81-
return path
82-
8387
def load(self) -> list[Rule]:
8488
try:
8589
rules = []
@@ -91,10 +95,12 @@ def load(self) -> list[Rule]:
9195
msg = "Please run: databricks labs ucx table-mapping"
9296
raise ValueError(msg) from None
9397

94-
def skip_table(self, backend: StatementExecutionBackend, schema: str, table: str):
98+
def skip_table(self, schema: str, table: str):
    """Mark a single table to be skipped by the migration via a table property."""
    statement = f"ALTER TABLE `{schema}`.`{table}` SET TBLPROPERTIES('{self.UCX_SKIP_PROPERTY}' = true)"
    try:
        self._backend.execute(statement)
    except NotFound as err:
        # A missing table gets a dedicated message; every other NotFound is logged verbatim.
        if "[TABLE_OR_VIEW_NOT_FOUND]" in str(err):
            logger.error(f"Failed to apply skip marker for Table {schema}.{table}. Table not found.")
        else:
            logger.error(err)
    except BadRequest as err:
        logger.error(err)
105111

106-
def skip_schema(self, backend: StatementExecutionBackend, schema: str):
112+
def skip_schema(self, schema: str):
    """Mark a whole schema to be skipped by the migration via a database property."""
    statement = f"ALTER SCHEMA `{schema}` SET DBPROPERTIES('{self.UCX_SKIP_PROPERTY}' = true)"
    try:
        self._backend.execute(statement)
    except NotFound as err:
        # A missing schema gets a dedicated message; every other NotFound is logged verbatim.
        if "[SCHEMA_NOT_FOUND]" in str(err):
            logger.error(f"Failed to apply skip marker for Schema {schema}. Schema not found.")
        else:
            logger.error(err)
    except BadRequest as err:
        logger.error(err)
123+
124+
def get_tables_to_migrate(self, tables_crawler: TablesCrawler):
    """Resolve the saved mapping rules against the assessment snapshot.

    Keeps only rules whose source table was seen by the assessment and
    whose database is not marked to be skipped, then runs the per-table
    checks (skip marker, already-migrated or stale upgrade target) in
    parallel. Returns the surviving TableToMigrate pairs.
    """
    rules = self.load()
    # Databases carrying the skip property are filtered out up front.
    databases_in_scope = self._get_databases_in_scope({rule.src_schema for rule in rules})
    snapshot_by_key = {crawled.key: crawled for crawled in tables_crawler.snapshot()}
    tasks = []
    for rule in rules:
        key = rule.as_hms_table_key
        if key not in snapshot_by_key:
            logger.info(f"Table {key} in the mapping doesn't show up in assessment")
            continue
        if rule.src_schema not in databases_in_scope:
            logger.info(f"Table {key} is in a database that was marked to be skipped")
            continue
        tasks.append(partial(self._get_table_in_scope_task, TableToMigrate(snapshot_by_key[key], rule)))

    return Threads.strict("checking all database properties", tasks)
142+
143+
def _overwrite_mapping(self, buffer) -> str:
    """Upload the CSV buffer as mapping.csv in the UCX folder, replacing any previous copy."""
    target = f"{self._folder}/mapping.csv"
    self._ws.workspace.upload(target, buffer, overwrite=True, format=ImportFormat.AUTO)
    return target
147+
148+
def _get_databases_in_scope(self, databases: set[str]):
    """Check every candidate database for the skip marker, in parallel."""
    tasks = [partial(self._get_database_in_scope_task, database) for database in databases]
    return Threads.strict("checking databases for skip property", tasks)
153+
154+
def _get_database_in_scope_task(self, database: str) -> str | None:
    """Return the database name if it is in migration scope, or None if it carries the skip marker."""
    described = {
        row["database_description_item"]: row["database_description_value"]
        for row in self._backend.fetch(f"DESCRIBE SCHEMA EXTENDED {database}")
    }
    # DESCRIBE reports properties as one "((k1,v1), (k2,v2))" string; lower-case
    # before parsing so the marker lookup is case-insensitive.
    properties = TablesCrawler.parse_database_props(described.get("Properties", "").lower())
    if self.UCX_SKIP_PROPERTY in properties:
        logger.info(f"Database {database} is marked to be skipped")
        return None
    return database
162+
163+
def _get_table_in_scope_task(self, table_to_migrate: TableToMigrate) -> TableToMigrate | None:
    """Decide whether a single mapped table should still be migrated.

    Returns None when the table must be skipped: its intended UC target
    already exists, it carries the UCX skip marker, or it was already
    upgraded to a target that still exists. As a side effect, a stale
    `upgraded_to` property (pointing at a target that no longer exists)
    is unset so the table can be migrated again.
    """
    table = table_to_migrate.src
    rule = table_to_migrate.rule

    # Nothing to do if the intended UC target is already there.
    if self._exists_in_uc(table, rule.as_uc_table_key):
        logger.info(f"The intended target for {table.key}, {rule.as_uc_table_key}, already exists.")
        return None
    result = self._backend.fetch(f"SHOW TBLPROPERTIES `{table.database}`.`{table.name}`")
    for value in result:
        if value["key"] == self.UCX_SKIP_PROPERTY:
            # Explicit skip marker set via skip_table.
            logger.info(f"{table.key} is marked to be skipped")
            return None
        if value["key"] == "upgraded_to":
            logger.info(f"{table.key} is set as upgraded to {value['value']}")
            if self._exists_in_uc(table, value["value"]):
                logger.info(
                    f"The table {table.key} was previously upgraded to {value['value']}. "
                    f"To revert the table and allow it to be upgraded again use the CLI command:"
                    f"databricks labs ucx revert --schema {table.database} --table {table.name}"
                )
                return None
            # Recorded target no longer exists: clear the stale marker and keep scanning.
            logger.info(f"The upgrade_to target for {table.key} is missing. Unsetting the upgrade_to property")
            self._backend.execute(table.sql_unset_upgraded_to())

    return table_to_migrate
188+
189+
def _exists_in_uc(self, src_table: Table, target_key: str):
190+
# Attempts to get the target table info from UC returns True if it exists.
191+
try:
192+
table_info = self._ws.tables.get(target_key)
193+
if not table_info.properties:
194+
return True
195+
upgraded_from = table_info.properties.get("upgraded_from")
196+
if upgraded_from and upgraded_from != src_table.key:
197+
msg = f"Expected to be migrated from {src_table.key}, but got {upgraded_from}. "
198+
"You can skip this error using the CLI command: "
199+
"databricks labs ucx skip "
200+
f"--schema {src_table.database} --table {src_table.name}"
201+
raise ResourceConflict(msg)
202+
return True
203+
except NotFound:
204+
return False

src/databricks/labs/ucx/hive_metastore/table_migrate.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,10 @@ def __init__(
2929

3030
def migrate_tables(self):
    """Migrate every in-scope table from the Hive metastore to Unity Catalog, in parallel."""
    self._init_seen_tables()
    # The mapping module decides which tables are still in scope (skip markers,
    # missing assessment entries, already-migrated or stale upgrade targets).
    candidates = self._tm.get_tables_to_migrate(self._tc)
    tasks = [partial(self._migrate_table, candidate.src, candidate.rule) for candidate in candidates]
    Threads.strict("migrate tables", tasks)
4137

4238
def _migrate_table(self, src_table: Table, rule: Rule):
@@ -188,9 +184,3 @@ def print_revert_report(self, *, delete_managed: bool) -> bool | None:
188184
print("Migrated Manged Tables (targets) will be left intact.")
189185
print("To revert and delete Migrated Tables, add --delete_managed true flag to the command.")
190186
return True
191-
192-
def _get_mapping_rules(self) -> dict[str, Rule]:
193-
mapping_rules: dict[str, Rule] = {}
194-
for rule in self._tm.load():
195-
mapping_rules[rule.as_hms_table_key] = rule
196-
return mapping_rules

src/databricks/labs/ucx/hive_metastore/tables.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ def _parse_table_props(tbl_props: str) -> dict:
140140
# Convert key-value pairs to dictionary
141141
return dict(key_value_pairs)
142142

143+
@staticmethod
144+
def parse_database_props(tbl_props: str) -> dict:
145+
pattern = r"([^,^\(^\)\[\]]+),([^,^\(^\)\[\]]+)"
146+
key_value_pairs = re.findall(pattern, tbl_props)
147+
# Convert key-value pairs to dictionary
148+
return dict(key_value_pairs)
149+
143150
def _try_load(self) -> Iterable[Table]:
144151
"""Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
145152
for row in self._fetch(f"SELECT * FROM {self._full_name}"):

tests/integration/conftest.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66

77
import databricks.sdk.core
88
import pytest
9-
from databricks.sdk import AccountClient
9+
from databricks.sdk import AccountClient, WorkspaceClient
1010
from databricks.sdk.core import Config
1111
from databricks.sdk.errors import NotFound
1212
from databricks.sdk.retries import retried
1313
from databricks.sdk.service.catalog import TableInfo
1414

1515
from databricks.labs.ucx.framework.crawlers import SqlBackend
1616
from databricks.labs.ucx.hive_metastore import TablesCrawler
17-
from databricks.labs.ucx.hive_metastore.mapping import Rule
17+
from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
1818
from databricks.labs.ucx.hive_metastore.tables import Table
1919
from databricks.labs.ucx.mixins.fixtures import * # noqa: F403
2020
from databricks.labs.ucx.workspace_access.groups import MigratedGroup
@@ -24,7 +24,6 @@
2424

2525
logger = logging.getLogger(__name__)
2626

27-
2827
retry_on_not_found = functools.partial(retried, on=[NotFound], timeout=timedelta(minutes=5))
2928
long_retry_on_not_found = functools.partial(retry_on_not_found, timeout=timedelta(minutes=15))
3029

@@ -128,7 +127,7 @@ def __init__(self, sql_backend: SqlBackend, schema: str, tables: list[TableInfo]
128127
object_type=f"{_.table_type.value}",
129128
view_text=_.view_definition,
130129
location=_.storage_location,
131-
table_format=f"{ _.data_source_format.value}" if _.table_type.value != "VIEW" else None, # type: ignore[arg-type]
130+
table_format=f"{_.data_source_format.value}" if _.table_type.value != "VIEW" else None, # type: ignore[arg-type]
132131
)
133132
for _ in tables
134133
]
@@ -137,9 +136,12 @@ def snapshot(self) -> list[Table]:
137136
return self._tables
138137

139138

140-
class StaticTableMapping:
141-
def __init__(self, rules: list[Rule] | None = None):
139+
class StaticTableMapping(TableMapping):
    """Test double for TableMapping that serves a fixed rule list instead of reading mapping.csv."""

    def __init__(
        self, ws: WorkspaceClient, backend: SqlBackend, folder: str | None = None, rules: list[Rule] | None = None
    ):
        super().__init__(ws, backend, folder)
        self._rules = rules

    def load(self):
        # Bypass the workspace-file lookup entirely and return the canned rules.
        return self._rules

0 commit comments

Comments
 (0)