Skip to content

Commit 49af2f5

Browse files
committed
Introduce WIP DirectFsAccessPyFixer for code replacement
1 parent a77ca8b commit 49af2f5

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

src/databricks/labs/ucx/source_code/linters/directfs.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import logging
22
from abc import ABC
33
from collections.abc import Iterable
4+
from typing import Any
45

56
from astroid import Call, InferenceError, NodeNG # type: ignore
67
from sqlglot.expressions import Alter, Create, Delete, Drop, Expression, Identifier, Insert, Literal, Select
78

9+
from databricks.labs.ucx.hive_metastore import TablesCrawler
810
from databricks.labs.ucx.source_code.base import (
911
Advice,
1012
Deprecation,
@@ -14,6 +16,7 @@
1416
DfsaSqlCollector,
1517
DirectFsAccess,
1618
)
19+
from databricks.labs.ucx.source_code.directfs_access import DirectFsAccessCrawler
1720
from databricks.labs.ucx.source_code.python.python_ast import (
1821
Tree,
1922
TreeVisitor,
@@ -207,3 +210,77 @@ def _walk_up(cls, expression: Expression | None) -> Expression | None:
207210
if isinstance(expression, (Create, Alter, Drop, Insert, Delete, Select)):
208211
return expression
209212
return cls._walk_up(expression.parent)
213+
214+
class DirectFsAccessPyFixer(DirectFsAccessPyLinter):
    """Work-in-progress fixer for direct filesystem access (DFSA) in Python code.

    Builds on the linter's DFSA detection: every node reported by
    ``collect_dfsas_from_tree`` is dispatched to a read or write replacement
    hook. The actual code rewriting is not implemented yet; the hooks only log
    the replacement they would perform.

    The path-to-table mapping used by the hooks is empty until
    ``populate_directfs_table_list`` has been called.
    """

    def __init__(
        self,
        session_state: CurrentSessionState,
        directfs_crawler: DirectFsAccessCrawler,
        tables_crawler: TablesCrawler,
        prevent_spark_duplicates=True,
    ):
        """Store the crawlers and start with an empty path-to-table mapping.

        Args:
            session_state: forwarded unchanged to the parent linter.
            directfs_crawler: crawler of recorded direct filesystem accesses.
            tables_crawler: crawler of known tables.
            prevent_spark_duplicates: forwarded unchanged to the parent linter.
        """
        super().__init__(session_state, prevent_spark_duplicates)
        self.directfs_crawler = directfs_crawler
        self.tables_crawler = tables_crawler
        # Maps a DFSA path to the details of the table that should replace it.
        # The attribute keeps its historical name, but it must be a mapping:
        # _replace_read looks entries up by path.
        self.direct_fs_table_list: dict[str, dict[str, Any]] = {}

    def fix_tree(self, tree: Tree) -> Tree:
        """Run the fixer over every DFSA node collected from ``tree``.

        Returns:
            The same ``tree`` object that was passed in.
        """
        for directfs_node in self.collect_dfsas_from_tree(tree):
            self._fix_node(directfs_node)
        return tree

    def _fix_node(self, directfs_node: DirectFsAccessNode) -> None:
        """Dispatch a DFSA node to the read or write replacement hook.

        Reads take precedence when both flags are set; nodes flagged as neither
        read nor write are left untouched.
        """
        dfsa = directfs_node.dfsa
        if dfsa.is_read:
            self._replace_read(directfs_node)
        elif dfsa.is_write:
            self._replace_write(directfs_node)

    def _replace_read(self, directfs_node: DirectFsAccessNode) -> None:
        """Log the table that would replace a direct filesystem read.

        Paths without an entry in ``direct_fs_table_list`` are skipped.
        """
        dfsa = directfs_node.dfsa
        dfsa_details = self.direct_fs_table_list.get(dfsa.path)
        if dfsa_details is None:
            # Nothing to replace with: either the mapping was never populated
            # or no table location matched this path.
            logger.debug(f"No table found for {dfsa.path}, skipping read replacement")
            return

        # TODO: Actual code replacement
        logger.info(
            f"Replacing read of {dfsa.path} with table {dfsa_details['dst_schema']}.{dfsa_details['dst_table']}"
        )

    def _replace_write(self, directfs_node: DirectFsAccessNode) -> None:
        """Log that a direct filesystem write would be replaced.

        TODO: Actual code replacement
        """
        dfsa = directfs_node.dfsa
        logger.info(f"Replacing write of {dfsa.path} with table")

    def populate_directfs_table_list(
        self,
        directfs_crawlers: list[DirectFsAccessCrawler],
        tables_crawler: TablesCrawler,
        workspace_name: str,
        catalog_name: str,
    ) -> None:
        """Map every crawled DFSA path to a table whose location contains it.

        Entries are added to ``self.direct_fs_table_list``; entries from earlier
        calls are kept.

        Args:
            directfs_crawlers: crawlers whose snapshots provide the DFSA records.
            tables_crawler: crawler whose snapshot provides the candidate tables.
            workspace_name: stored verbatim in every entry.
            catalog_name: stored verbatim in every entry.

        Raises:
            ValueError: if the tables snapshot is empty, or if the crawlers
                yield no DFSA record at all.
        """
        directfs_snapshot = [
            directfs_access for crawler in directfs_crawlers for directfs_access in crawler.snapshot()
        ]
        tables_snapshot = list(tables_crawler.snapshot())
        if not tables_snapshot:
            msg = "No tables found. Please run: databricks labs ucx ensure-assessment-run"
            raise ValueError(msg)
        if not directfs_snapshot:
            msg = "No directfs references found in code"
            raise ValueError(msg)

        # TODO: very inefficient search, just for initial testing
        for table in tables_snapshot:
            if not table.location:
                # A table without a location cannot match any path.
                continue
            for directfs_record in directfs_snapshot:
                if directfs_record.path not in table.location:
                    continue
                # When several tables match one path, the first match in
                # snapshot order is kept.
                self.direct_fs_table_list.setdefault(
                    directfs_record.path,
                    {
                        "workspace_name": workspace_name,
                        "is_read": directfs_record.is_read,
                        "is_write": directfs_record.is_write,
                        "catalog_name": catalog_name,
                        "dst_schema": table.database,
                        "dst_table": table.name,
                    },
                )

0 commit comments

Comments
 (0)