Skip to content

Commit f8bf94c

Browse files
authored
Make fixer diagnostic codes unique (#3582)
## Changes Make fixer diagnostic codes unique so that the right fixer can be found for code migration/fixing. ### Linked issues Progresses #3514 Breaks up #3520 ### Functionality - [x] modified existing command: `databricks labs ucx migrate-local-code` ### Tests - [ ] manually tested - [x] modified and added unit tests - [x] modified and added integration tests
1 parent 310d9ff commit f8bf94c

30 files changed

+334
-202
lines changed

docs/ucx/docs/dev/contributing.mdx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,9 @@ rdd-in-shared-clusters
302302
spark-logging-in-shared-clusters
303303
sql-parse-error
304304
sys-path-cannot-compute-value
305-
table-migrated-to-uc
305+
table-migrated-to-uc-python
306+
table-migrated-to-uc-python-sql
307+
table-migrated-to-uc-sql
306308
to-json-in-shared-clusters
307309
unsupported-magic-line
308310
```

docs/ucx/docs/reference/linter_codes.mdx

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ spark.table(f"foo_{some_table_name}")
3131

3232
We even detect string constants when coming either from `dbutils.widgets.get` (via job named parameters) or through
3333
loop variables. If `old.things` table is migrated to `brand.new.stuff` in Unity Catalog, the following code will
34-
trigger two messages: [`table-migrated-to-uc`](#table-migrated-to-uc) for the first query, as the contents are clearly
35-
analysable, and `cannot-autofix-table-reference` for the second query.
34+
trigger two messages: [`table-migrated-to-uc-{sql,python,python-sql}`](#table-migrated-to-uc-sqlpythonpython-sql) for
35+
the first query, as the contents are clearly analysable, and `cannot-autofix-table-reference` for the second query.
3636

3737
```python
38-
# ucx[table-migrated-to-uc:+4:4:+4:20] Table old.things is migrated to brand.new.stuff in Unity Catalog
38+
# ucx[table-migrated-to-uc-python-sql:+4:4:+4:20] Table old.things is migrated to brand.new.stuff in Unity Catalog
3939
# ucx[cannot-autofix-table-reference:+3:4:+3:20] Can't migrate table_name argument in 'spark.sql(query)' because its value cannot be computed
4040
table_name = f"table_{index}"
4141
for query in ["SELECT * FROM old.things", f"SELECT * FROM {table_name}"]:
@@ -247,12 +247,16 @@ analysis where the path is located.
247247

248248

249249

250-
## `table-migrated-to-uc`
250+
## `table-migrated-to-uc-{sql,python,python-sql}`
251251

252252
This message indicates that the linter has found a table that has been migrated to Unity Catalog. The user must ensure
253253
that the table is available in Unity Catalog.
254254

255-
255+
| Postfix | Explanation |
256+
|------------|-------------------------------------------------|
257+
| sql | Table reference in SparkSQL |
258+
| python | Table reference in PySpark |
259+
| python-sql | Table reference in SparkSQL called from PySpark |
256260

257261
## `to-json-in-shared-clusters`
258262

src/databricks/labs/ucx/source_code/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ class Fixer(ABC):
161161

162162
@property
163163
@abstractmethod
164-
def name(self) -> str: ...
164+
def diagnostic_code(self) -> str:
165+
"""The diagnostic code that this fixer fixes."""
166+
167+
def is_supported(self, diagnostic_code: str) -> bool:
168+
"""Indicate if the diagnostic code is supported by this fixer."""
169+
return self.diagnostic_code is not None and diagnostic_code == self.diagnostic_code
165170

166171
@abstractmethod
167172
def apply(self, code: str) -> str: ...

src/databricks/labs/ucx/source_code/linters/context.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from databricks.labs.ucx.source_code.linters.imports import DbutilsPyLinter
2525

2626
from databricks.labs.ucx.source_code.linters.pyspark import (
27-
SparkSqlPyLinter,
27+
DirectFsAccessSqlPylinter,
28+
FromTableSqlPyLinter,
2829
SparkTableNamePyLinter,
2930
SparkSqlTablePyCollector,
3031
)
@@ -57,7 +58,7 @@ def __init__(
5758
sql_linters.append(from_table)
5859
sql_fixers.append(from_table)
5960
sql_table_collectors.append(from_table)
60-
spark_sql = SparkSqlPyLinter(from_table, from_table)
61+
spark_sql = FromTableSqlPyLinter(from_table)
6162
python_linters.append(spark_sql)
6263
python_fixers.append(spark_sql)
6364
python_table_collectors.append(SparkSqlTablePyCollector(from_table))
@@ -75,7 +76,7 @@ def __init__(
7576
DBRv8d0PyLinter(dbr_version=session_state.dbr_version),
7677
SparkConnectPyLinter(session_state),
7778
DbutilsPyLinter(session_state),
78-
SparkSqlPyLinter(sql_direct_fs, None),
79+
DirectFsAccessSqlPylinter(sql_direct_fs),
7980
]
8081

8182
python_dfsa_collectors += [DirectFsAccessPyLinter(session_state, prevent_spark_duplicates=False)]
@@ -112,10 +113,16 @@ def linter(self, language: Language) -> Linter:
112113
raise ValueError(f"Unsupported language: {language}")
113114

114115
def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None:
115-
if language not in self._fixers:
116-
return None
117-
for fixer in self._fixers[language]:
118-
if fixer.name == diagnostic_code:
116+
"""Get the fixer for a language that matches the code.
117+
118+
The first fixer whose name matches the diagnostic code is returned. This logic assumes the fixers have
119+
unique names.
120+
121+
Returns:
122+
Fixer | None : The fixer if a match is found, otherwise None.
123+
"""
124+
for fixer in self._fixers.get(language, []):
125+
if fixer.is_supported(diagnostic_code):
119126
return fixer
120127
return None
121128

src/databricks/labs/ucx/source_code/linters/from_table.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ def __init__(self, index: TableMigrationIndex, session_state: CurrentSessionStat
4545
self._session_state: CurrentSessionState = session_state
4646

4747
@property
48-
def name(self) -> str:
49-
return 'table-migrate'
48+
def diagnostic_code(self) -> str:
49+
"""The diagnostic codes that this fixer fixes."""
50+
return "table-migrated-to-uc-sql"
5051

5152
@property
5253
def schema(self) -> str:
@@ -58,7 +59,7 @@ def lint_expression(self, expression: Expression) -> Iterable[Deprecation]:
5859
if not dst:
5960
return
6061
yield Deprecation(
61-
code='table-migrated-to-uc',
62+
code="table-migrated-to-uc-sql",
6263
message=f"Table {info.schema_name}.{info.table_name} is migrated to {dst.destination()} in Unity Catalog",
6364
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
6465
start_line=0,

src/databricks/labs/ucx/source_code/linters/pyspark.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import logging
23
from abc import ABC, abstractmethod
34
from collections.abc import Iterable, Iterator
@@ -18,7 +19,11 @@
1819
TableSqlCollector,
1920
DfsaSqlCollector,
2021
)
21-
from databricks.labs.ucx.source_code.linters.directfs import DIRECT_FS_ACCESS_PATTERNS, DirectFsAccessNode
22+
from databricks.labs.ucx.source_code.linters.directfs import (
23+
DIRECT_FS_ACCESS_PATTERNS,
24+
DirectFsAccessNode,
25+
DirectFsAccessSqlLinter,
26+
)
2227
from databricks.labs.ucx.source_code.python.python_infer import InferredValue
2328
from databricks.labs.ucx.source_code.linters.from_table import FromTableSqlLinter
2429
from databricks.labs.ucx.source_code.python.python_ast import (
@@ -155,7 +160,7 @@ def lint(
155160
if dst is None:
156161
continue
157162
yield Deprecation.from_node(
158-
code='table-migrated-to-uc',
163+
code='table-migrated-to-uc-python',
159164
message=f"Table {used_table[0]} is migrated to {dst.destination()} in Unity Catalog",
160165
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
161166
node=node,
@@ -387,6 +392,14 @@ def matchers(self) -> dict[str, _TableNameMatcher]:
387392

388393

389394
class SparkTableNamePyLinter(PythonLinter, Fixer, TablePyCollector):
395+
"""Linter for table name references in PySpark
396+
397+
Examples:
398+
1. Find table name references
399+
``` python
400+
spark.read.table("hive_metastore.schema.table")
401+
```
402+
"""
390403

391404
def __init__(
392405
self,
@@ -400,9 +413,9 @@ def __init__(
400413
self._spark_matchers = SparkTableNameMatchers(False).matchers
401414

402415
@property
403-
def name(self) -> str:
404-
# this is the same fixer, just in a different language context
405-
return self._from_table.name
416+
def diagnostic_code(self) -> str:
417+
"""The diagnostic codes that this fixer fixes."""
418+
return "table-migrated-to-uc-python"
406419

407420
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
408421
for node in tree.walk():
@@ -461,28 +474,32 @@ def _visit_call_nodes(cls, tree: Tree) -> Iterable[tuple[Call, NodeNG]]:
461474
yield call_node, query
462475

463476

464-
class SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, Fixer):
477+
class _SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, Fixer):
478+
"""Linter for SparkSQL used within PySpark."""
465479

466480
def __init__(self, sql_linter: SqlLinter, sql_fixer: Fixer | None):
467481
self._sql_linter = sql_linter
468482
self._sql_fixer = sql_fixer
469483

470-
@property
471-
def name(self) -> str:
472-
return "<none>" if self._sql_fixer is None else self._sql_fixer.name
473-
474484
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
485+
inferable_values = []
475486
for call_node, query in self._visit_call_nodes(tree):
476487
for value in InferredValue.infer_from_node(query):
477-
if not value.is_inferred():
488+
if value.is_inferred():
489+
inferable_values.append((call_node, value))
490+
else:
478491
yield Advisory.from_node(
479492
code="cannot-autofix-table-reference",
480493
message=f"Can't migrate table_name argument in '{query.as_string()}' because its value cannot be computed",
481494
node=call_node,
482495
)
483-
continue
484-
for advice in self._sql_linter.lint(value.as_string()):
485-
yield advice.replace_from_node(call_node)
496+
for call_node, value in inferable_values:
497+
for advice in self._sql_linter.lint(value.as_string()):
498+
# Replacing the fixer code to indicate that the SparkSQL fixer is wrapped with PySpark
499+
code = advice.code
500+
if self._sql_fixer and code == self._sql_fixer.diagnostic_code:
501+
code = self.diagnostic_code
502+
yield dataclasses.replace(advice.replace_from_node(call_node), code=code)
486503

487504
def apply(self, code: str) -> str:
488505
if not self._sql_fixer:
@@ -503,6 +520,45 @@ def apply(self, code: str) -> str:
503520
return tree.node.as_string()
504521

505522

523+
class FromTableSqlPyLinter(_SparkSqlPyLinter):
524+
"""Lint tables and views in Spark SQL wrapped by PySpark code.
525+
526+
Examples:
527+
1. Find table name reference in SparkSQL:
528+
``` python
529+
spark.sql("SELECT * FROM hive_metastore.schema.table").collect()
530+
```
531+
"""
532+
533+
def __init__(self, sql_linter: FromTableSqlLinter):
534+
super().__init__(sql_linter, sql_linter)
535+
536+
@property
537+
def diagnostic_code(self) -> str:
538+
"""The diagnostic codes that this fixer fixes."""
539+
return "table-migrated-to-uc-python-sql"
540+
541+
542+
class DirectFsAccessSqlPylinter(_SparkSqlPyLinter):
543+
"""Lint direct file system access in Spark SQL wrapped by PySpark code.
544+
545+
Examples:
546+
1. Find direct file system access in SparkSQL:
547+
``` python
548+
spark.sql("SELECT * FROM parquet.`/dbfs/path/to/table`").collect()
549+
```
550+
"""
551+
552+
def __init__(self, sql_linter: DirectFsAccessSqlLinter):
553+
# TODO: Implement fixer for direct filesystem access (https://github.com/databrickslabs/ucx/issues/2021)
554+
super().__init__(sql_linter, None)
555+
556+
@property
557+
def diagnostic_code(self) -> str:
558+
"""The diagnostic codes that this fixer fixes."""
559+
return "direct-filesystem-access-python-sql"
560+
561+
506562
class SparkSqlDfsaPyCollector(_SparkSqlAnalyzer, DfsaPyCollector):
507563

508564
def __init__(self, sql_collector: DfsaSqlCollector):
Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from collections.abc import Iterable
2+
from pathlib import Path
3+
14
import astroid # type: ignore[import-untyped]
25
from databricks.labs.blueprint.wheels import ProductInfo
36

@@ -6,30 +9,55 @@
69

710

811
def main():
9-
# pylint: disable=too-many-nested-blocks
10-
codes = set()
12+
"""Walk the UCX code base to find all diagnostic linting codes."""
13+
codes = set[str]()
1114
product_info = ProductInfo.from_class(Advice)
1215
source_code = product_info.version_file().parent
13-
for file in source_code.glob("**/*.py"):
14-
maybe_tree = MaybeTree.from_source_code(file.read_text())
15-
if not maybe_tree.tree:
16-
continue
17-
tree = maybe_tree.tree
18-
# recursively detect values of "code" kwarg in calls
19-
for node in tree.walk():
20-
if not isinstance(node, astroid.Call):
21-
continue
22-
for keyword in node.keywords:
23-
name = keyword.arg
24-
if name != "code":
25-
continue
26-
if not isinstance(keyword.value, astroid.Const):
27-
continue
28-
problem_code = keyword.value.value
29-
codes.add(problem_code)
16+
for path in source_code.glob("**/*.py"):
17+
codes.update(_find_diagnostic_codes(path))
3018
for code in sorted(codes):
3119
print(code)
3220

3321

22+
def _find_diagnostic_codes(file: Path) -> Iterable[str]:
23+
"""Walk the Python ast tree to find the diagnostic codes."""
24+
maybe_tree = MaybeTree.from_source_code(file.read_text())
25+
if not maybe_tree.tree:
26+
return
27+
for node in maybe_tree.tree.walk():
28+
diagnostic_code = None
29+
if isinstance(node, astroid.ClassDef):
30+
diagnostic_code = _find_diagnostic_code_in_class_def(node)
31+
elif isinstance(node, astroid.Call):
32+
diagnostic_code = _find_diagnostic_code_in_call(node)
33+
if diagnostic_code:
34+
yield diagnostic_code
35+
36+
37+
def _find_diagnostic_code_in_call(node: astroid.Call) -> str | None:
38+
"""Find the diagnostic code in a call node."""
39+
for keyword in node.keywords:
40+
if keyword.arg == "code" and isinstance(keyword.value, astroid.Const):
41+
problem_code = keyword.value.value
42+
return problem_code
43+
return None
44+
45+
46+
def _find_diagnostic_code_in_class_def(node: astroid.ClassDef) -> str | None:
47+
"""Find the diagnostic code in a class definition node."""
48+
diagnostic_code_methods = []
49+
for child_node in node.body:
50+
if isinstance(child_node, astroid.FunctionDef) and child_node.name == "diagnostic_code":
51+
diagnostic_code_methods.append(child_node)
52+
if diagnostic_code_methods and diagnostic_code_methods[0].body:
53+
problem_code = diagnostic_code_methods[0].body[0].value.value
54+
return problem_code
55+
return None
56+
57+
3458
if __name__ == "__main__":
3559
main()
60+
61+
62+
def test_main():
63+
main()

0 commit comments

Comments
 (0)