diff --git a/src/databricks/labs/ucx/hive_metastore/table_migrate.py b/src/databricks/labs/ucx/hive_metastore/table_migrate.py
index d58e6b8da4..732a223f77 100644
--- a/src/databricks/labs/ucx/hive_metastore/table_migrate.py
+++ b/src/databricks/labs/ucx/hive_metastore/table_migrate.py
@@ -108,7 +108,6 @@ def migrate_tables(
         if what in [What.DB_DATASET, What.UNKNOWN]:
             logger.error(f"Can't migrate tables with type {what.name}")
             return None
-        self._init_seen_tables()
         if what == What.VIEW:
             return self._migrate_views()
         return self._migrate_tables(
@@ -124,6 +123,7 @@ def _migrate_tables(
     ):
         tables_to_migrate = self._table_mapping.get_tables_to_migrate(self._tables_crawler, check_uc_table)
         tables_in_scope = filter(lambda t: t.src.what == what, tables_to_migrate)
+        self._init_seen_tables(scope={t.rule.as_uc_table for t in tables_to_migrate})
         tasks = []
         for table in tables_in_scope:
             tasks.append(
@@ -597,8 +597,8 @@ def print_revert_report(
             print("To revert and delete Migrated Tables, add --delete_managed true flag to the command")
         return True
 
-    def _init_seen_tables(self):
-        self._seen_tables = self._migration_status_refresher.get_seen_tables()
+    def _init_seen_tables(self, *, scope: set[Table] | None = None):
+        self._seen_tables = self._migration_status_refresher.get_seen_tables(scope=scope)
 
     def _sql_alter_to(self, table: Table, target_table_key: str):
         return f"ALTER {table.kind} {escape_sql_identifier(table.key)} SET TBLPROPERTIES ('upgraded_to' = '{target_table_key}');"
diff --git a/src/databricks/labs/ucx/hive_metastore/table_migration_status.py b/src/databricks/labs/ucx/hive_metastore/table_migration_status.py
index 8b766755fe..d307986dd6 100644
--- a/src/databricks/labs/ucx/hive_metastore/table_migration_status.py
+++ b/src/databricks/labs/ucx/hive_metastore/table_migration_status.py
@@ -2,16 +2,18 @@
 import logging
 from dataclasses import dataclass, replace
 from collections.abc import Iterable, KeysView
+from functools import partial
 from typing import ClassVar
 
+from databricks.labs.blueprint.parallel import Threads
 from databricks.labs.lsql.backends import SqlBackend
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import DatabricksError, NotFound
-from databricks.sdk.service.catalog import CatalogInfo, CatalogInfoSecurableKind, SchemaInfo
+from databricks.sdk.service.catalog import CatalogInfo, CatalogInfoSecurableKind, SchemaInfo, TableInfo, CatalogType
 
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
-from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
+from databricks.labs.ucx.hive_metastore.tables import TablesCrawler, Table
 
 logger = logging.getLogger(__name__)
 
@@ -84,6 +86,8 @@ class TableMigrationStatusRefresher(CrawlerBase[TableMigrationStatus]):
         CatalogInfoSecurableKind.CATALOG_SYSTEM,
     ]
 
+    API_LIMIT: ClassVar[int] = 25
+
     def __init__(self, ws: WorkspaceClient, sql_backend: SqlBackend, schema, tables_crawler: TablesCrawler):
         super().__init__(sql_backend, "hive_metastore", schema, "migration_status", TableMigrationStatus)
         self._ws = ws
@@ -92,29 +96,58 @@ def __init__(self, ws: WorkspaceClient, sql_backend: SqlBackend, schema, tables_
     def index(self, *, force_refresh: bool = False) -> TableMigrationIndex:
         return TableMigrationIndex(self.snapshot(force_refresh=force_refresh))
 
-    def get_seen_tables(self) -> dict[str, str]:
+    def get_seen_tables(self, *, scope: set[Table] | None = None) -> dict[str, str]:
+        """Collect tables that carry an `upgraded_from` property, keyed by their full name.
+
+        Tables are listed schema by schema through `Threads.gather`, using up to `API_LIMIT`
+        threads. Schemas named `information_schema` are skipped.
+
+        Args:
+            scope: When given and non-empty, only schemas and tables matching these tables are
+                considered; when `None` or empty, no scope filtering is applied.
+
+        Returns:
+            The lower-cased full name of each seen table mapped to its lower-cased
+            `upgraded_from` property.
+        """
         seen_tables: dict[str, str] = {}
+        tasks = []
+        schema_scope = {(table.catalog.lower(), table.database.lower()) for table in scope} if scope else None
+        table_scope = {table.full_name.lower() for table in scope} if scope else None
         for schema in self._iter_schemas():
-            if schema.catalog_name is None or schema.name is None:
+            if (
+                schema.catalog_name is None
+                or schema.name is None
+                or schema.name == "information_schema"
+                or (schema_scope and (schema.catalog_name.lower(), schema.name.lower()) not in schema_scope)
+            ):
                 continue
-            try:
-                # ws.tables.list returns Iterator[TableInfo], so we need to convert it to a list in order to catch the exception
-                tables = list(self._ws.tables.list(catalog_name=schema.catalog_name, schema_name=schema.name))
-            except NotFound:
-                logger.warning(f"Schema {schema.full_name} no longer exists. Skipping checking its migration status.")
+            tasks.append(
+                partial(
+                    self._iter_tables,
+                    schema.catalog_name,
+                    schema.name,
+                )
+            )
+        tables: list = []
+        logger.info(f"Scanning {len(tasks)} schemas for tables")
+        table_lists = Threads.gather("list tables", tasks, self.API_LIMIT)
+        # The first element of the gather result holds the per-schema table lists; flatten them
+        for table_list in table_lists[0]:
+            tables.extend(table_list)
+        for table in tables:
+            if not isinstance(table, TableInfo):
+                logger.warning(f"Table {table} is not an instance of TableInfo")
                 continue
-            except DatabricksError as e:
-                logger.warning(f"Error while listing tables in schema: {schema.full_name}", exc_info=e)
+            if not table.full_name:
+                logger.warning(f"The table {table.name} in {table.schema_name} has no full name")
                 continue
-            for table in tables:
-                if not table.properties:
-                    continue
-                if "upgraded_from" not in table.properties:
-                    continue
-                if not table.full_name:
-                    logger.warning(f"The table {table.name} in {schema.name} has no full name")
-                    continue
-                seen_tables[table.full_name.lower()] = table.properties["upgraded_from"].lower()
+            if table_scope and table.full_name.lower() not in table_scope:
+                continue
+            if not table.properties or "upgraded_from" not in table.properties:
+                continue
+
+            seen_tables[table.full_name.lower()] = table.properties["upgraded_from"].lower()
         return seen_tables
 
     def is_migrated(self, schema: str, table: str) -> bool:
@@ -171,7 +204,10 @@ def _iter_catalogs(self) -> Iterable[CatalogInfo]:
 
     def _iter_schemas(self) -> Iterable[SchemaInfo]:
         for catalog in self._iter_catalogs():
-            if catalog.name is None:
+            if catalog.name is None or catalog.catalog_type in (
+                CatalogType.DELTASHARING_CATALOG,
+                CatalogType.SYSTEM_CATALOG,
+            ):
                 continue
             try:
                 yield from self._ws.schemas.list(catalog_name=catalog.name)
@@ -181,3 +217,17 @@ def _iter_schemas(self) -> Iterable[SchemaInfo]:
             except DatabricksError as e:
                 logger.warning(f"Error while listing schemas in catalog: {catalog.name}", exc_info=e)
                 continue
+
+    def _iter_tables(self, catalog_name: str, schema_name: str) -> list[TableInfo]:
+        """List the tables in a schema; log a warning and return an empty list on NotFound or DatabricksError."""
+        try:
+            # ws.tables.list returns Iterator[TableInfo], so we need to convert it to a list in order to catch the exception
+            return list(self._ws.tables.list(catalog_name=catalog_name, schema_name=schema_name))
+        except NotFound:
+            logger.warning(
+                f"Schema {catalog_name}.{schema_name} no longer exists. Skipping checking its migration status."
+            )
+            return []
+        except DatabricksError as e:
+            logger.warning(f"Error while listing tables in schema: {catalog_name}.{schema_name}", exc_info=e)
+            return []
diff --git a/tests/unit/hive_metastore/test_table_migrate.py b/tests/unit/hive_metastore/test_table_migrate.py
index 1712852585..db5996d79c 100644
--- a/tests/unit/hive_metastore/test_table_migrate.py
+++ b/tests/unit/hive_metastore/test_table_migrate.py
@@ -498,7 +498,7 @@ def test_migrate_already_upgraded_table_should_produce_no_queries(ws, mock_pyspa
     table_crawler = TablesCrawler(crawler_backend, "inventory_database")
     ws.catalogs.list.return_value = [CatalogInfo(name="cat1")]
     ws.schemas.list.return_value = [
-        SchemaInfo(catalog_name="cat1", name="test_schema1"),
+        SchemaInfo(catalog_name="cat1", name="schema1"),
     ]
     ws.tables.list.return_value = [
         TableInfo(
@@ -1043,15 +1043,24 @@ def test_table_status_seen_tables(caplog):
     table_crawler = create_autospec(TablesCrawler)
     client = create_autospec(WorkspaceClient)
     client.catalogs.list.return_value = [CatalogInfo(name="cat1"), CatalogInfo(name="deleted_cat")]
-    client.schemas.list.side_effect = [
-        [
+    schemas = {
+        "cat1": [
             SchemaInfo(catalog_name="cat1", name="schema1", full_name="cat1.schema1"),
             SchemaInfo(catalog_name="cat1", name="deleted_schema", full_name="cat1.deleted_schema"),
         ],
-        NotFound(),
-    ]
-    client.tables.list.side_effect = [
-        [
+        "deleted_cat": None,
+    }
+
+    def schema_list(catalog_name):
+        schema = schemas[catalog_name]
+        if not schema:
+            raise NotFound()
+        return schema
+
+    client.schemas.list = schema_list
+
+    tables = {
+        ("cat1", "schema1"): [
             TableInfo(
                 catalog_name="cat1",
                 schema_name="schema1",
@@ -1086,8 +1095,16 @@ def test_table_status_seen_tables(caplog):
                 properties={"upgraded_from": "hive_metastore.schema1.table2"},
             ),
         ],
-        NotFound(),
-    ]
+        ("cat1", "deleted_schema"): None,
+    }
+
+    def table_list(catalog_name, schema_name):
+        table = tables[(catalog_name, schema_name)]
+        if not table:
+            raise NotFound()
+        return table
+
+    client.tables.list = table_list
     table_status_crawler = TableMigrationStatusRefresher(client, backend, "ucx", table_crawler)
     seen_tables = table_status_crawler.get_seen_tables()
     assert seen_tables == {
diff --git a/tests/unit/hive_metastore/test_table_migration_status.py b/tests/unit/hive_metastore/test_table_migration_status.py
index ad9350e92c..1203fd12b3 100644
--- a/tests/unit/hive_metastore/test_table_migration_status.py
+++ b/tests/unit/hive_metastore/test_table_migration_status.py
@@ -6,7 +6,7 @@
 from databricks.sdk.errors import BadRequest, DatabricksError, NotFound
 from databricks.sdk.service.catalog import CatalogInfoSecurableKind, CatalogInfo, SchemaInfo, TableInfo
 
-from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
+from databricks.labs.ucx.hive_metastore.tables import TablesCrawler, Table
 from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationStatusRefresher
 
 
@@ -119,3 +119,54 @@ def tables_list(catalog_name: str, schema_name: str) -> Iterable[TableInfo]:
     ws.schemas.list.assert_called_once_with(catalog_name="test")  # System is NOT called
     ws.tables.list.assert_called()
     tables_crawler.snapshot.assert_not_called()
+
+
+def test_table_migration_status_refresher_scope(mock_backend) -> None:
+    ws = create_autospec(WorkspaceClient)
+    ws.catalogs.list.return_value = [
+        CatalogInfo(name="test1"),
+        CatalogInfo(name="test2"),
+    ]
+
+    def schemas_list(catalog_name: str) -> Iterable[SchemaInfo]:
+        schemas = [
+            SchemaInfo(catalog_name="test1", name="test1"),
+            SchemaInfo(catalog_name="test2", name="test2"),
+        ]
+        for schema in schemas:
+            if schema.catalog_name == catalog_name:
+                yield schema
+
+    def tables_list(catalog_name: str, schema_name: str) -> Iterable[TableInfo]:
+        tables = [
+            TableInfo(
+                full_name="test1.test1.test1",
+                catalog_name="test1",
+                schema_name="test1",
+                name="test1",
+                properties={"upgraded_from": "test1"},
+            ),
+            TableInfo(
+                full_name="test2.test2.test2",
+                catalog_name="test2",
+                schema_name="test2",
+                name="test2",
+                properties={"upgraded_from": "test2"},
+            ),
+        ]
+        for table in tables:
+            if table.catalog_name == catalog_name and table.schema_name == schema_name:
+                yield table
+
+    ws.schemas.list.side_effect = schemas_list
+    ws.tables.list.side_effect = tables_list
+    tables_crawler = create_autospec(TablesCrawler)
+    refresher = TableMigrationStatusRefresher(ws, mock_backend, "test", tables_crawler)
+
+    scope_table = Table("test1", "test1", "test1", "Table", "Delta")
+    # Test with scope
+    assert refresher.get_seen_tables(scope={scope_table}) == {"test1.test1.test1": "test1"}
+    # Test without scope
+    assert refresher.get_seen_tables() == {"test1.test1.test1": "test1", "test2.test2.test2": "test2"}
+    ws.tables.list.assert_called()
+    tables_crawler.snapshot.assert_not_called()