Scoping "Seen Tables". Improving table scan performance. #3741

Open: wants to merge 17 commits into main
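In short, get_seen_tables gains an optional scope so that only the catalogs and schemas targeted by the current migration are scanned, instead of the whole metastore. A rough usage sketch (the refresher variable is assumed; the scope expression comes from the diff below):

```python
# Hypothetical call pattern for the scoped scan introduced in this PR.
scope = {t.rule.as_uc_table for t in tables_to_migrate}  # UC targets of this run
seen = refresher.get_seen_tables(scope=scope)  # skips schemas outside the scope
```
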
6 changes: 3 additions & 3 deletions src/databricks/labs/ucx/hive_metastore/table_migrate.py
@@ -108,7 +108,6 @@ def migrate_tables(
if what in [What.DB_DATASET, What.UNKNOWN]:
logger.error(f"Can't migrate tables with type {what.name}")
return None
self._init_seen_tables()
if what == What.VIEW:
return self._migrate_views()
return self._migrate_tables(
@@ -124,6 +123,7 @@ def _migrate_tables(
):
tables_to_migrate = self._table_mapping.get_tables_to_migrate(self._tables_crawler, check_uc_table)
tables_in_scope = filter(lambda t: t.src.what == what, tables_to_migrate)
self._init_seen_tables(scope={t.rule.as_uc_table for t in tables_to_migrate})
tasks = []
for table in tables_in_scope:
tasks.append(
@@ -597,8 +597,8 @@ def print_revert_report(
print("To revert and delete Migrated Tables, add --delete_managed true flag to the command")
return True

def _init_seen_tables(self):
self._seen_tables = self._migration_status_refresher.get_seen_tables()
def _init_seen_tables(self, *, scope: set[Table] | None = None):
self._seen_tables = self._migration_status_refresher.get_seen_tables(scope=scope)

def _sql_alter_to(self, table: Table, target_table_key: str):
return f"ALTER {table.kind} {escape_sql_identifier(table.key)} SET TBLPROPERTIES ('upgraded_to' = '{target_table_key}');"
78 changes: 57 additions & 21 deletions src/databricks/labs/ucx/hive_metastore/table_migration_status.py
@@ -2,16 +2,18 @@
import logging
from dataclasses import dataclass, replace
from collections.abc import Iterable, KeysView
from functools import partial
from typing import ClassVar

from databricks.labs.blueprint.parallel import Threads
from databricks.labs.lsql.backends import SqlBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import DatabricksError, NotFound
from databricks.sdk.service.catalog import CatalogInfo, CatalogInfoSecurableKind, SchemaInfo
from databricks.sdk.service.catalog import CatalogInfo, CatalogInfoSecurableKind, SchemaInfo, TableInfo, CatalogType

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler, Table

logger = logging.getLogger(__name__)

@@ -84,6 +86,8 @@ class TableMigrationStatusRefresher(CrawlerBase[TableMigrationStatus]):
CatalogInfoSecurableKind.CATALOG_SYSTEM,
]

API_LIMIT: ClassVar[int] = 25

def __init__(self, ws: WorkspaceClient, sql_backend: SqlBackend, schema, tables_crawler: TablesCrawler):
super().__init__(sql_backend, "hive_metastore", schema, "migration_status", TableMigrationStatus)
self._ws = ws
@@ -92,29 +96,45 @@ def __init__(self, ws: WorkspaceClient, sql_backend: SqlBackend, schema, tables_
def index(self, *, force_refresh: bool = False) -> TableMigrationIndex:
return TableMigrationIndex(self.snapshot(force_refresh=force_refresh))

def get_seen_tables(self) -> dict[str, str]:
def get_seen_tables(self, *, scope: set[Table] | None = None) -> dict[str, str]:
seen_tables: dict[str, str] = {}
tasks = []
schema_scope = {(table.catalog.lower(), table.database.lower()) for table in scope} if scope else None
table_scope = {table.full_name.lower() for table in scope} if scope else None
for schema in self._iter_schemas():
if schema.catalog_name is None or schema.name is None:
if (
schema.catalog_name is None
or schema.name is None
or schema.name == "information_schema"
Review comment (Contributor): Move this into _iter_schemas.
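A minimal sketch of that suggestion (a hypothetical refactor, not part of this diff): filter information_schema inside _iter_schemas so callers need no extra check.

```python
def _iter_schemas(self) -> Iterable[SchemaInfo]:
    for catalog in self._iter_catalogs():
        if catalog.name is None:
            continue
        try:
            for schema in self._ws.schemas.list(catalog_name=catalog.name):
                # Moved here from get_seen_tables: the system schema is never migrated.
                if schema.name == "information_schema":
                    continue
                yield schema
        except NotFound:
            logger.warning(f"Catalog {catalog.name} no longer exists. Skipping.")
        except DatabricksError as e:
            logger.warning(f"Error while listing schemas in catalog: {catalog.name}", exc_info=e)
```
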
or (schema_scope and (schema.catalog_name.lower(), schema.name.lower()) not in schema_scope)
):
continue
try:
# ws.tables.list returns Iterator[TableInfo], so we need to convert it to a list in order to catch the exception
tables = list(self._ws.tables.list(catalog_name=schema.catalog_name, schema_name=schema.name))
except NotFound:
logger.warning(f"Schema {schema.full_name} no longer exists. Skipping checking its migration status.")
tasks.append(
partial(
self._iter_tables,
schema.catalog_name,
schema.name,
)
)
tables: list = []
logger.info(f"Scanning {len(tasks)} schemas for tables")
table_lists = Threads.gather("list tables", tasks, self.API_LIMIT)
Review comment (Contributor): It is not an API limit, but the number of threads.
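A rename along those lines (hypothetical name; value and behavior unchanged) would make the intent explicit:

```python
# Caps the number of concurrent listing threads; not a server-side quota.
MAX_LIST_THREADS: ClassVar[int] = 25

# ...and at the call site:
table_lists = Threads.gather("list tables", tasks, self.MAX_LIST_THREADS)
```
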

# Combine tuple of lists to a list
Review comment (Contributor): Use itertools.chain instead.
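A sketch of that suggestion, assuming Threads.gather returns a (results, errors) tuple, as the existing table_lists[0] indexing implies:

```python
from itertools import chain

# Flatten the per-schema lists without a manual extend loop.
table_lists, _errors = Threads.gather("list tables", tasks, self.API_LIMIT)
tables = list(chain.from_iterable(table_lists))
```
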

for table_list in table_lists[0]:
tables.extend(table_list)
for table in tables:
if not isinstance(table, TableInfo):
Review comment (Contributor): Is this supposed to happen? It should not based on the type hinting, right?

logger.warning(f"Table {table} is not an instance of TableInfo")
continue
except DatabricksError as e:
logger.warning(f"Error while listing tables in schema: {schema.full_name}", exc_info=e)
if not table.full_name:
logger.warning(f"The table {table.name} in {table.schema_name} has no full name")
continue
for table in tables:
if not table.properties:
continue
if "upgraded_from" not in table.properties:
continue
if not table.full_name:
logger.warning(f"The table {table.name} in {schema.name} has no full name")
continue
seen_tables[table.full_name.lower()] = table.properties["upgraded_from"].lower()
if table_scope and table.full_name.lower() not in table_scope:
continue
if not table.properties or "upgraded_from" not in table.properties:
continue

seen_tables[table.full_name.lower()] = table.properties["upgraded_from"].lower()
return seen_tables

def is_migrated(self, schema: str, table: str) -> bool:
@@ -171,7 +191,10 @@ def _iter_catalogs(self) -> Iterable[CatalogInfo]:

def _iter_schemas(self) -> Iterable[SchemaInfo]:
for catalog in self._iter_catalogs():
if catalog.name is None:
if catalog.name is None or catalog.catalog_type in (
Review comment (Contributor): This probably has the same effect as the filter on line 186. If not, move the filter there so that catalog filtering happens inside _iter_catalogs.
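A sketch of that consolidation (hypothetical; _skip_kinds stands in for the class's existing securable-kind skip list):

```python
def _iter_catalogs(self) -> Iterable[CatalogInfo]:
    try:
        for catalog in self._ws.catalogs.list():
            if catalog.name is None:
                continue
            if catalog.securable_kind in self._skip_kinds:  # assumed attribute name
                continue
            if catalog.catalog_type in (CatalogType.DELTASHARING_CATALOG, CatalogType.SYSTEM_CATALOG):
                continue
            yield catalog
    except DatabricksError as e:
        logger.warning("Error while listing catalogs", exc_info=e)
```
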

CatalogType.DELTASHARING_CATALOG,
CatalogType.SYSTEM_CATALOG,
):
continue
try:
yield from self._ws.schemas.list(catalog_name=catalog.name)
@@ -181,3 +204,16 @@ def _iter_schemas(self) -> Iterable[SchemaInfo]:
except DatabricksError as e:
logger.warning(f"Error while listing schemas in catalog: {catalog.name}", exc_info=e)
continue

def _iter_tables(self, catalog_name: str, schema_name: str) -> list[TableInfo]:
try:
# ws.tables.list returns Iterator[TableInfo], so we need to convert it to a list in order to catch the exception
return list(self._ws.tables.list(catalog_name=catalog_name, schema_name=schema_name))
except NotFound:
logger.warning(
f"Schema {catalog_name}.{schema_name} no longer exists. Skipping checking its migration status."
)
return []
except DatabricksError as e:
logger.warning(f"Error while listing tables in schema: {schema_name}", exc_info=e)
return []
35 changes: 26 additions & 9 deletions tests/unit/hive_metastore/test_table_migrate.py
@@ -498,7 +498,7 @@ def test_migrate_already_upgraded_table_should_produce_no_queries(ws, mock_pyspa
table_crawler = TablesCrawler(crawler_backend, "inventory_database")
ws.catalogs.list.return_value = [CatalogInfo(name="cat1")]
ws.schemas.list.return_value = [
SchemaInfo(catalog_name="cat1", name="test_schema1"),
SchemaInfo(catalog_name="cat1", name="schema1"),
]
ws.tables.list.return_value = [
TableInfo(
@@ -1043,15 +1043,24 @@ def test_table_status_seen_tables(caplog):
table_crawler = create_autospec(TablesCrawler)
client = create_autospec(WorkspaceClient)
client.catalogs.list.return_value = [CatalogInfo(name="cat1"), CatalogInfo(name="deleted_cat")]
client.schemas.list.side_effect = [
[
schemas = {
"cat1": [
SchemaInfo(catalog_name="cat1", name="schema1", full_name="cat1.schema1"),
SchemaInfo(catalog_name="cat1", name="deleted_schema", full_name="cat1.deleted_schema"),
],
NotFound(),
]
client.tables.list.side_effect = [
[
"deleted_cat": None,
}

def schema_list(catalog_name):
schema = schemas[catalog_name]
if not schema:
raise NotFound()
return schema

client.schemas.list = schema_list

tables = {
("cat1", "schema1"): [
TableInfo(
catalog_name="cat1",
schema_name="schema1",
@@ -1086,8 +1095,16 @@ def test_table_status_seen_tables(caplog):
properties={"upgraded_from": "hive_metastore.schema1.table2"},
),
],
NotFound(),
]
("cat1", "deleted_schema"): None,
}

def table_list(catalog_name, schema_name):
table = tables[(catalog_name, schema_name)]
if not table:
raise NotFound()
return table

client.tables.list = table_list
table_status_crawler = TableMigrationStatusRefresher(client, backend, "ucx", table_crawler)
seen_tables = table_status_crawler.get_seen_tables()
assert seen_tables == {
53 changes: 52 additions & 1 deletion tests/unit/hive_metastore/test_table_migration_status.py
@@ -6,7 +6,7 @@
from databricks.sdk.errors import BadRequest, DatabricksError, NotFound
from databricks.sdk.service.catalog import CatalogInfoSecurableKind, CatalogInfo, SchemaInfo, TableInfo

from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler, Table
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationStatusRefresher


@@ -119,3 +119,54 @@ def tables_list(catalog_name: str, schema_name: str) -> Iterable[TableInfo]:
ws.schemas.list.assert_called_once_with(catalog_name="test") # System is NOT called
ws.tables.list.assert_called()
tables_crawler.snapshot.assert_not_called()


def test_table_migration_status_refresher_scope(mock_backend) -> None:
ws = create_autospec(WorkspaceClient)
ws.catalogs.list.return_value = [
CatalogInfo(name="test1"),
CatalogInfo(name="test2"),
]

def schemas_list(catalog_name: str) -> Iterable[SchemaInfo]:
schemas = [
SchemaInfo(catalog_name="test1", name="test1"),
SchemaInfo(catalog_name="test2", name="test2"),
]
for schema in schemas:
if schema.catalog_name == catalog_name:
yield schema

def tables_list(catalog_name: str, schema_name: str) -> Iterable[TableInfo]:
tables = [
TableInfo(
full_name="test1.test1.test1",
catalog_name="test1",
schema_name="test1",
name="test1",
properties={"upgraded_from": "test1"},
),
TableInfo(
full_name="test2.test2.test2",
catalog_name="test2",
schema_name="test2",
name="test2",
properties={"upgraded_from": "test2"},
),
]
for table in tables:
if table.catalog_name == catalog_name and table.schema_name == schema_name:
yield table

ws.schemas.list.side_effect = schemas_list
ws.tables.list.side_effect = tables_list
tables_crawler = create_autospec(TablesCrawler)
refresher = TableMigrationStatusRefresher(ws, mock_backend, "test", tables_crawler)

scope_table = Table("test1", "test1", "test1", "Table", "Delta")
# Test with scope
assert refresher.get_seen_tables(scope={scope_table}) == {"test1.test1.test1": "test1"}
# Test without scope
assert refresher.get_seen_tables() == {"test1.test1.test1": "test1", "test2.test2.test2": "test2"}
ws.tables.list.assert_called()
tables_crawler.snapshot.assert_not_called()