|
4 | 4 | import dataclasses
|
5 | 5 | import locale
|
6 | 6 | import logging
|
| 7 | +import sys |
7 | 8 | from abc import abstractmethod, ABC
|
8 | 9 | from collections.abc import Iterable
|
9 |
| -from dataclasses import dataclass |
| 10 | +from dataclasses import dataclass, field |
| 11 | +from datetime import datetime |
10 | 12 | from pathlib import Path
|
| 13 | +from typing import Any |
11 | 14 |
|
12 | 15 | from astroid import AstroidSyntaxError, NodeNG # type: ignore
|
13 | 16 | from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError
|
|
19 | 22 |
|
20 | 23 | from databricks.labs.ucx.source_code.python.python_ast import Tree
|
21 | 24 |
|
| 25 | +if sys.version_info >= (3, 11): |
| 26 | + from typing import Self |
| 27 | +else: |
| 28 | + from typing_extensions import Self |
| 29 | + |
22 | 30 | # Code mapping between LSP, PyLint, and our own diagnostics:
|
23 | 31 | # | LSP | PyLint | Our |
|
24 | 32 | # |---------------------------|------------|----------------|
|
@@ -174,6 +182,140 @@ def name(self) -> str: ...
|
174 | 182 | def apply(self, code: str) -> str: ...
|
175 | 183 |
|
176 | 184 |
|
| 185 | +@dataclass |
| 186 | +class LineageAtom: |
| 187 | + |
| 188 | + object_type: str |
| 189 | + object_id: str |
| 190 | + other: dict[str, str] | None = None |
| 191 | + |
| 192 | + |
| 193 | +@dataclass |
| 194 | +class SourceInfo: |
| 195 | + |
| 196 | + @classmethod |
| 197 | + def from_dict(cls, data: dict[str, Any]) -> Self: |
| 198 | + source_lineage = data.get("source_lineage", None) |
| 199 | + if isinstance(source_lineage, list) and len(source_lineage) > 0 and isinstance(source_lineage[0], dict): |
| 200 | + lineage_atoms = [LineageAtom(**lineage) for lineage in source_lineage] |
| 201 | + data["source_lineage"] = lineage_atoms |
| 202 | + return cls(**data) |
| 203 | + |
| 204 | + UNKNOWN = "unknown" |
| 205 | + |
| 206 | + source_id: str = UNKNOWN |
| 207 | + source_timestamp: datetime = datetime.fromtimestamp(0) |
| 208 | + source_lineage: list[LineageAtom] = field(default_factory=list) |
| 209 | + assessment_start_timestamp: datetime = datetime.fromtimestamp(0) |
| 210 | + assessment_end_timestamp: datetime = datetime.fromtimestamp(0) |
| 211 | + |
| 212 | + def replace_source( |
| 213 | + self, |
| 214 | + source_id: str | None = None, |
| 215 | + source_lineage: list[LineageAtom] | None = None, |
| 216 | + source_timestamp: datetime | None = None, |
| 217 | + ): |
| 218 | + return dataclasses.replace( |
| 219 | + self, |
| 220 | + source_id=source_id or self.source_id, |
| 221 | + source_timestamp=source_timestamp or self.source_timestamp, |
| 222 | + source_lineage=source_lineage or self.source_lineage, |
| 223 | + ) |
| 224 | + |
| 225 | + def replace_assessment_infos( |
| 226 | + self, assessment_start: datetime | None = None, assessment_end: datetime | None = None |
| 227 | + ): |
| 228 | + return dataclasses.replace( |
| 229 | + self, |
| 230 | + assessment_start_timestamp=assessment_start or self.assessment_start_timestamp, |
| 231 | + assessment_end_timestamp=assessment_end or self.assessment_end_timestamp, |
| 232 | + ) |
| 233 | + |
| 234 | + |
| 235 | +@dataclass |
| 236 | +class UsedTable(SourceInfo): |
| 237 | + |
| 238 | + @classmethod |
| 239 | + def parse(cls, value: str, default_schema: str) -> UsedTable: |
| 240 | + parts = value.split(".") |
| 241 | + if len(parts) >= 3: |
| 242 | + catalog_name = parts.pop(0) |
| 243 | + else: |
| 244 | + catalog_name = "hive_metastore" |
| 245 | + if len(parts) >= 2: |
| 246 | + schema_name = parts.pop(0) |
| 247 | + else: |
| 248 | + schema_name = default_schema |
| 249 | + return UsedTable(catalog_name=catalog_name, schema_name=schema_name, table_name=parts[0]) |
| 250 | + |
| 251 | + catalog_name: str = SourceInfo.UNKNOWN |
| 252 | + schema_name: str = SourceInfo.UNKNOWN |
| 253 | + table_name: str = SourceInfo.UNKNOWN |
| 254 | + is_read: bool = True |
| 255 | + is_write: bool = False |
| 256 | + |
| 257 | + |
class TableCollector(ABC):
    """Interface for components that can extract table references from source code."""

    @abstractmethod
    def collect_tables(self, source_code: str) -> Iterable[UsedTable]: ...
| 262 | + |
| 263 | + |
@dataclass
class TableInfoNode:
    """A used table together with the AST node where it was found."""

    table: UsedTable  # the detected table reference
    node: NodeNG  # astroid node the reference originates from
| 268 | + |
| 269 | + |
class TablePyCollector(TableCollector, ABC):
    """Table collector for Python sources: parse once, then walk the tree."""

    def collect_tables(self, source_code: str):
        """Parse *source_code* and yield the tables found by the tree walker."""
        parsed = Tree.normalize_and_parse(source_code)
        yield from (info_node.table for info_node in self.collect_tables_from_tree(parsed))

    @abstractmethod
    def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: ...
| 279 | + |
| 280 | + |
class TableSqlCollector(TableCollector, ABC):
    """Marker base class for table collectors that operate on SQL sources."""

    ...
| 282 | + |
| 283 | + |
@dataclass
class DirectFsAccess(SourceInfo):
    """A record describing a Direct File System Access"""

    path: str = SourceInfo.UNKNOWN  # the file-system path that is accessed
    is_read: bool = False  # True when the path is read from
    is_write: bool = False  # True when the path is written to
| 291 | + |
| 292 | + |
@dataclass
class DirectFsAccessNode:
    """A direct file-system access together with the AST node where it was found."""

    dfsa: DirectFsAccess  # the detected file-system access
    node: NodeNG  # astroid node the access originates from
| 297 | + |
| 298 | + |
class DfsaCollector(ABC):
    """Interface for components that can extract direct file-system accesses from source code."""

    @abstractmethod
    def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: ...
| 303 | + |
| 304 | + |
class DfsaPyCollector(DfsaCollector, ABC):
    """DFSA collector for Python sources: parse once, then walk the tree."""

    def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]:
        """Parse *source_code* and yield the file-system accesses found by the tree walker."""
        parsed = Tree.normalize_and_parse(source_code)
        yield from (dfsa_node.dfsa for dfsa_node in self.collect_dfsas_from_tree(parsed))

    @abstractmethod
    def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: ...
| 314 | + |
| 315 | + |
class DfsaSqlCollector(DfsaCollector, ABC):
    """Marker base class for DFSA collectors that operate on SQL sources."""

    ...
| 317 | + |
| 318 | + |
177 | 319 | # The default schema to use when the schema is not specified in a table reference
|
178 | 320 | # See: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-qry-select-usedb.html
|
179 | 321 | DEFAULT_CATALOG = 'hive_metastore'
|
@@ -221,20 +363,42 @@ def parse_security_mode(mode_str: str | None) -> compute.DataSecurityMode | None
|
221 | 363 | return None
|
222 | 364 |
|
223 | 365 |
|
224 |
class SqlSequentialLinter(SqlLinter, DfsaCollector, TableCollector):
    """Composite SQL linter that fans out to a sequence of linters and collectors."""

    def __init__(
        self,
        linters: list[SqlLinter],
        dfsa_collectors: list[DfsaSqlCollector],
        table_collectors: list[TableSqlCollector],
    ):
        # Delegates are stored as-is; each public method chains their output in order.
        self._table_collectors = table_collectors
        self._dfsa_collectors = dfsa_collectors
        self._linters = linters

    def lint_expression(self, expression: Expression) -> Iterable[Advice]:
        """Yield the advices produced by every registered linter, in registration order."""
        yield from (advice for sub in self._linters for advice in sub.lint_expression(expression))

    def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]:
        """Yield the direct file-system accesses reported by every registered collector."""
        yield from (dfsa for sub in self._dfsa_collectors for dfsa in sub.collect_dfsas(source_code))

    def collect_tables(self, source_code: str) -> Iterable[UsedTable]:
        """Yield the used tables reported by every registered collector."""
        yield from (table for sub in self._table_collectors for table in sub.collect_tables(source_code))
233 | 389 |
|
234 |
class PythonSequentialLinter(Linter, DfsaCollector, TableCollector):
    """Composite Python linter that also collects used tables and direct file-system accesses."""

    def __init__(
        self,
        linters: list[PythonLinter],
        dfsa_collectors: list[DfsaPyCollector],
        table_collectors: list[TablePyCollector],
    ):
        self._linters = linters
        self._dfsa_collectors = dfsa_collectors
        self._table_collectors = table_collectors
        # Accumulated parse tree of code seen so far; created lazily (see _make_tree).
        self._tree: Tree | None = None
|
239 | 403 |
|
240 | 404 | def lint(self, code: str) -> Iterable[Advice]:
|
@@ -271,6 +435,30 @@ def process_child_cell(self, code: str):
|
271 | 435 | # error already reported when linting enclosing notebook
|
272 | 436 | logger.warning(f"Failed to parse Python cell: {code}", exc_info=e)
|
273 | 437 |
|
| 438 | + def collect_dfsas(self, source_code: str) -> Iterable[DirectFsAccess]: |
| 439 | + try: |
| 440 | + tree = self._parse_and_append(source_code) |
| 441 | + for dfsa_node in self.collect_dfsas_from_tree(tree): |
| 442 | + yield dfsa_node.dfsa |
| 443 | + except AstroidSyntaxError as e: |
| 444 | + logger.warning('syntax-error', exc_info=e) |
| 445 | + |
| 446 | + def collect_dfsas_from_tree(self, tree: Tree) -> Iterable[DirectFsAccessNode]: |
| 447 | + for collector in self._dfsa_collectors: |
| 448 | + yield from collector.collect_dfsas_from_tree(tree) |
| 449 | + |
| 450 | + def collect_tables(self, source_code: str) -> Iterable[UsedTable]: |
| 451 | + try: |
| 452 | + tree = self._parse_and_append(source_code) |
| 453 | + for table_node in self.collect_tables_from_tree(tree): |
| 454 | + yield table_node.table |
| 455 | + except AstroidSyntaxError as e: |
| 456 | + logger.warning('syntax-error', exc_info=e) |
| 457 | + |
| 458 | + def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: |
| 459 | + for collector in self._table_collectors: |
| 460 | + yield from collector.collect_tables_from_tree(tree) |
| 461 | + |
274 | 462 | def _make_tree(self) -> Tree:
|
275 | 463 | if self._tree is None:
|
276 | 464 | self._tree = Tree.new_module()
|
|
0 commit comments