Skip to content

Commit d2d0179

Browse files
authored
Ensure that USE statements are recognized and apply to table references without a qualifying schema in SQL and pyspark (#1433)
1 parent fd645fd commit d2d0179

File tree

14 files changed

+356
-49
lines changed

14 files changed

+356
-49
lines changed

src/databricks/labs/ucx/hive_metastore/view_migrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex, TableView
1010
from databricks.labs.ucx.hive_metastore.mapping import TableToMigrate
11+
from databricks.labs.ucx.source_code.base import CurrentSessionState
1112
from databricks.labs.ucx.source_code.queries import FromTable
1213

1314
logger = logging.getLogger(__name__)
@@ -41,7 +42,7 @@ def _view_dependencies(self):
4142
yield TableView("hive_metastore", src_db, old_table.name)
4243

4344
def sql_migrate_view(self, index: MigrationIndex) -> str:
44-
from_table = FromTable(index, use_schema=self.src.database)
45+
from_table = FromTable(index, CurrentSessionState(self.src.database))
4546
assert self.src.view_text is not None, 'Expected a view text'
4647
migrated_select = from_table.apply(self.src.view_text)
4748
statements = sqlglot.parse(migrated_select, read='databricks')

src/databricks/labs/ucx/source_code/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,25 @@ def name(self) -> str: ...
8282
def apply(self, code: str) -> str: ...
8383

8484

85+
# The default schema to use when the schema is not specified in a table reference
86+
# See: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-qry-select-usedb.html
87+
DEFAULT_SCHEMA = 'default'
88+
89+
90+
@dataclass
91+
class CurrentSessionState:
92+
"""
93+
A data class that represents the current state of a session.
94+
95+
This class can be used to track various aspects of a session, such as the current schema.
96+
97+
Attributes:
98+
schema (str): The current schema of the session. If not provided, it defaults to 'DEFAULT_SCHEMA'.
99+
"""
100+
101+
schema: str = DEFAULT_SCHEMA
102+
103+
85104
class SequentialLinter(Linter):
86105
def __init__(self, linters: list[Linter]):
87106
self._linters = linters

src/databricks/labs/ucx/source_code/dbfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def name() -> str:
9191
return 'dbfs-query'
9292

9393
def lint(self, code: str) -> Iterable[Advice]:
94-
for statement in sqlglot.parse(code, dialect='databricks'):
94+
for statement in sqlglot.parse(code, read='databricks'):
9595
if not statement:
9696
continue
9797
for table in statement.find_all(Table):

src/databricks/labs/ucx/source_code/languages.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from databricks.sdk.service.workspace import Language
22

33
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
4-
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter
4+
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState
55
from databricks.labs.ucx.source_code.pyspark import SparkSql
66
from databricks.labs.ucx.source_code.queries import FromTable
77
from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter, FromDbfsFolder
@@ -11,7 +11,8 @@
1111
class Languages:
1212
def __init__(self, index: MigrationIndex):
1313
self._index = index
14-
from_table = FromTable(index)
14+
session_state = CurrentSessionState()
15+
from_table = FromTable(index, session_state=session_state)
1516
dbfs_from_folder = FromDbfsFolder()
1617
self._linters = {
1718
Language.PYTHON: SequentialLinter(

src/databricks/labs/ucx/source_code/notebook.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def requires_isolated_pi(self) -> str:
243243

244244
@classmethod
245245
def of_language(cls, language: Language) -> CellLanguage:
246+
# TODO: Should this not raise a ValueError if the language is not found?
247+
# It also causes a GeneratorExit exception to be raised. Maybe an explicit loop is better.
246248
return next((cl for cl in CellLanguage if cl.language == language))
247249

248250
@classmethod

src/databricks/labs/ucx/source_code/notebook_linter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Iterable
22

3+
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
34
from databricks.labs.ucx.source_code.base import Advice
45
from databricks.labs.ucx.source_code.notebook import Notebook
56
from databricks.labs.ucx.source_code.languages import Languages, Language
@@ -16,7 +17,8 @@ def __init__(self, langs: Languages, notebook: Notebook):
1617
self._notebook: Notebook = notebook
1718

1819
@classmethod
19-
def from_source(cls, langs: Languages, source: str, default_language: Language) -> 'NotebookLinter':
20+
def from_source(cls, index: MigrationIndex, source: str, default_language: Language) -> 'NotebookLinter':
21+
langs = Languages(index)
2022
notebook = Notebook.parse("", source, default_language)
2123
assert notebook is not None
2224
return cls(langs, notebook)

src/databricks/labs/ucx/source_code/pyspark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class TableNameMatcher(Matcher):
7979
def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
8080
table_arg = self._get_table_arg(node)
8181
if isinstance(table_arg, ast.Constant):
82-
dst = self._find_dest(index, table_arg.value)
82+
dst = self._find_dest(index, table_arg.value, from_table.schema)
8383
if dst is not None:
8484
yield Deprecation(
8585
code='table-migrate',
@@ -104,13 +104,16 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) ->
104104
def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
105105
table_arg = self._get_table_arg(node)
106106
assert isinstance(table_arg, ast.Constant)
107-
dst = self._find_dest(index, table_arg.value)
107+
dst = self._find_dest(index, table_arg.value, from_table.schema)
108108
if dst is not None:
109109
table_arg.value = dst.destination()
110110

111111
@staticmethod
112-
def _find_dest(index: MigrationIndex, value: str):
112+
def _find_dest(index: MigrationIndex, value: str, schema: str):
113113
parts = value.split(".")
114+
# Ensure that unqualified table references use the current schema
115+
if len(parts) == 1:
116+
return index.get(schema, parts[0])
114117
return None if len(parts) != 2 else index.get(parts[0], parts[1])
115118

116119

src/databricks/labs/ucx/source_code/python_linter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory
1010

11-
1211
logger = logging.getLogger(__name__)
1312

1413

src/databricks/labs/ucx/source_code/queries.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,71 @@
22

33
import logging
44
import sqlglot
5-
from sqlglot.expressions import Table, Expression
5+
from sqlglot.expressions import Table, Expression, Use
66
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
7-
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter
7+
from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState
88

99
logger = logging.getLogger(__name__)
1010

1111

1212
class FromTable(Linter, Fixer):
13-
def __init__(self, index: MigrationIndex, *, use_schema: str | None = None):
14-
self._index = index
15-
self._use_schema = use_schema
13+
"""Linter and Fixer for table migrations in SQL queries.
14+
15+
This class is responsible for identifying and fixing table migrations in
16+
SQL queries.
17+
"""
18+
19+
def __init__(self, index: MigrationIndex, session_state: CurrentSessionState):
20+
"""
21+
Initializes the FromTable class.
22+
23+
Args:
24+
index: The migration index, which is a mapping of source tables to destination tables.
25+
session_state: The current session state, which will be used to track the current schema.
26+
27+
We need to be careful with the nomenclature here. For instance when parsing a table reference,
28+
sqlglot uses `db` instead of `schema` to refer to the schema. The following table references
29+
show how sqlglot represents them:::
30+
31+
catalog.schema.table -> Table(catalog='catalog', db='schema', this='table')
32+
schema.table -> Table(catalog='', db='schema', this='table')
33+
table -> Table(catalog='', db='', this='table')
34+
"""
35+
self._index: MigrationIndex = index
36+
self._session_state: CurrentSessionState = session_state if session_state else CurrentSessionState()
1637

1738
def name(self) -> str:
1839
return 'table-migrate'
1940

41+
@property
42+
def schema(self):
43+
return self._session_state.schema
44+
2045
def lint(self, code: str) -> Iterable[Advice]:
21-
for statement in sqlglot.parse(code, dialect='databricks'):
46+
for statement in sqlglot.parse(code, read='databricks'):
2247
if not statement:
2348
continue
2449
for table in statement.find_all(Table):
25-
catalog = self._catalog(table)
26-
if catalog != 'hive_metastore':
50+
if isinstance(statement, Use):
51+
# Sqlglot captures the database name in the Use statement as a Table, with
52+
# the schema as the table name.
53+
self._session_state.schema = table.name
2754
continue
28-
src_db = table.db if table.db else self._use_schema
29-
if not src_db:
55+
56+
# we only migrate tables in the hive_metastore catalog
57+
if self._catalog(table) != 'hive_metastore':
58+
continue
59+
# Sqlglot uses db instead of schema, watch out for that
60+
src_schema = table.db if table.db else self._session_state.schema
61+
if not src_schema:
3062
logger.error(f"Could not determine schema for table {table.name}")
3163
continue
32-
dst = self._index.get(src_db, table.name)
64+
dst = self._index.get(src_schema, table.name)
3365
if not dst:
3466
continue
3567
yield Deprecation(
3668
code='table-migrate',
37-
message=f"Table {table.db}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
69+
message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
3870
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
3971
start_line=0,
4072
start_col=0,
@@ -53,12 +85,17 @@ def apply(self, code: str) -> str:
5385
for statement in sqlglot.parse(code, read='databricks'):
5486
if not statement:
5587
continue
88+
if isinstance(statement, Use):
89+
table = statement.this
90+
self._session_state.schema = table.name
91+
new_statements.append(statement.sql('databricks'))
92+
continue
5693
for old_table in self._dependent_tables(statement):
57-
src_db = old_table.db if old_table.db else self._use_schema
58-
if not src_db:
94+
src_schema = old_table.db if old_table.db else self._session_state.schema
95+
if not src_schema:
5996
logger.error(f"Could not determine schema for table {old_table.name}")
6097
continue
61-
dst = self._index.get(src_db, old_table.name)
98+
dst = self._index.get(src_schema, old_table.name)
6299
if not dst:
63100
continue
64101
new_table = Table(catalog=dst.dst_catalog, db=dst.dst_schema, this=dst.dst_table)

tests/unit/source_code/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,24 @@ def migration_index():
1919
MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'),
2020
]
2121
)
22+
23+
24+
@pytest.fixture
25+
def extended_test_index():
26+
return MigrationIndex(
27+
[
28+
MigrationStatus('old', 'things', dst_catalog='brand', dst_schema='new', dst_table='stuff'),
29+
MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'),
30+
MigrationStatus('old', 'stuff', dst_catalog='brand', dst_schema='new', dst_table='things'),
31+
MigrationStatus('other', 'issues', dst_catalog='some', dst_schema='certain', dst_table='matters'),
32+
MigrationStatus('default', 'testtable', dst_catalog='cata', dst_schema='nondefault', dst_table='table'),
33+
MigrationStatus('different_db', 'testtable', dst_catalog='cata2', dst_schema='newspace', dst_table='table'),
34+
MigrationStatus('old', 'testtable', dst_catalog='cata3', dst_schema='newspace', dst_table='table'),
35+
MigrationStatus('default', 'people', dst_catalog='cata4', dst_schema='nondefault', dst_table='newpeople'),
36+
MigrationStatus(
37+
'something', 'persons', dst_catalog='cata4', dst_schema='newsomething', dst_table='persons'
38+
),
39+
MigrationStatus('whatever', 'kittens', dst_catalog='cata4', dst_schema='felines', dst_table='toms'),
40+
MigrationStatus('whatever', 'numbers', dst_catalog='cata4', dst_schema='counting', dst_table='numbers'),
41+
]
42+
)

0 commit comments

Comments
 (0)