2
2
from abc import ABC
3
3
from collections .abc import Iterable
4
4
5
- from astroid import Attribute , Call , Const , InferenceError , JoinedStr , Name , NodeNG # type: ignore
5
+ from astroid import Call , InferenceError , NodeNG # type: ignore
6
6
from sqlglot .expressions import Alter , Create , Delete , Drop , Expression , Identifier , Insert , Literal , Select
7
7
8
8
from databricks .labs .ucx .source_code .base import (
16
16
DfsaSqlCollector ,
17
17
DirectFsAccess ,
18
18
)
19
- from databricks .labs .ucx .source_code .python .python_ast import Tree , TreeVisitor
19
+ from databricks .labs .ucx .source_code .python .python_ast import Tree , TreeVisitor , TreeHelper
20
20
from databricks .labs .ucx .source_code .python .python_infer import InferredValue
21
21
from databricks .labs .ucx .source_code .sql .sql_parser import SqlParser , SqlExpression
22
22
@@ -68,43 +68,37 @@ class _DetectDirectFsAccessVisitor(TreeVisitor):
68
68
def __init__ (self , session_state : CurrentSessionState , prevent_spark_duplicates : bool ) -> None :
69
69
self ._session_state = session_state
70
70
self ._directfs_nodes : list [DirectFsAccessNode ] = []
71
- self ._reported_locations : set [tuple [int , int ]] = set ()
72
71
self ._prevent_spark_duplicates = prevent_spark_duplicates
73
72
74
73
def visit_call (self , node : Call ):
75
74
for arg in node .args :
76
- self ._visit_arg (arg )
75
+ self ._visit_arg (node , arg )
77
76
78
- def _visit_arg (self , arg : NodeNG ):
77
+ def _visit_arg (self , call : Call , arg : NodeNG ):
79
78
try :
80
79
for inferred in InferredValue .infer_from_node (arg , self ._session_state ):
81
80
if not inferred .is_inferred ():
82
81
logger .debug (f"Could not infer value of { arg .as_string ()} " )
83
82
continue
84
- self ._check_str_constant ( arg , inferred )
83
+ self ._check_str_arg ( call , arg , inferred )
85
84
except InferenceError as e :
86
85
logger .debug (f"Could not infer value of { arg .as_string ()} " , exc_info = e )
87
86
88
- def visit_const (self , node : Const ):
89
- # Constant strings yield Advisories
90
- if isinstance (node .value , str ):
91
- self ._check_str_constant (node , InferredValue ([node ]))
92
-
93
- def _check_str_constant (self , source_node : NodeNG , inferred : InferredValue ):
94
- if self ._already_reported (source_node , inferred ):
95
- return
96
- # don't report on JoinedStr fragments
97
- if isinstance (source_node .parent , JoinedStr ):
98
- return
87
+ def _check_str_arg (self , call_node : Call , arg_node : NodeNG , inferred : InferredValue ):
99
88
value = inferred .as_string ()
100
89
for pattern in DIRECT_FS_ACCESS_PATTERNS :
101
90
if not pattern .matches (value ):
102
91
continue
103
- # avoid false positives with relative URLs
104
- if self ._is_http_call_parameter (source_node ):
92
+ # only capture 'open' calls or calls originating from spark or dbutils
93
+ # because there is no other known way to manipulate data directly from file system
94
+ tree = Tree (call_node )
95
+ is_open = TreeHelper .get_call_name (call_node ) == "open" and tree .is_builtin ()
96
+ is_from_db_utils = False if is_open else tree .is_from_module ("dbutils" )
97
+ is_from_spark = False if is_open or is_from_db_utils else tree .is_from_module ("spark" )
98
+ if not (is_open or is_from_db_utils or is_from_spark ):
105
99
return
106
100
# avoid duplicate advices that are reported by SparkSqlPyLinter
107
- if self ._prevent_spark_duplicates and Tree ( source_node ). is_from_module ( "spark" ) :
101
+ if self ._prevent_spark_duplicates and is_from_spark :
108
102
return
109
103
# since we're normally filtering out spark calls, we're dealing with dfsas we know little about
110
104
# notably we don't know is_read or is_write
@@ -113,39 +107,8 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
113
107
is_read = True ,
114
108
is_write = False ,
115
109
)
116
- self ._directfs_nodes .append (DirectFsAccessNode (dfsa , source_node ))
117
- self ._reported_locations .add ((source_node .lineno , source_node .col_offset ))
118
-
119
- @classmethod
120
- def _is_http_call_parameter (cls , source_node : NodeNG ):
121
- if not isinstance (source_node .parent , Call ):
122
- return False
123
- # for now we only cater for ws.api_client.do
124
- return cls ._is_ws_api_client_do_call (source_node )
125
-
126
- @classmethod
127
- def _is_ws_api_client_do_call (cls , source_node : NodeNG ):
128
- assert isinstance (source_node .parent , Call )
129
- func = source_node .parent .func
130
- if not isinstance (func , Attribute ) or func .attrname != "do" :
131
- return False
132
- expr = func .expr
133
- if not isinstance (expr , Attribute ) or expr .attrname != "api_client" :
134
- return False
135
- expr = expr .expr
136
- if not isinstance (expr , Name ):
137
- return False
138
- for value in InferredValue .infer_from_node (expr ):
139
- if not value .is_inferred ():
140
- continue
141
- for node in value .nodes :
142
- return Tree (node ).is_instance_of ("WorkspaceClient" )
143
- # at this point is seems safer to assume that expr.expr is a workspace than the opposite
144
- return True
145
-
146
- def _already_reported (self , source_node : NodeNG , inferred : InferredValue ):
147
- all_nodes = [source_node ] + inferred .nodes
148
- return any ((node .lineno , node .col_offset ) in self ._reported_locations for node in all_nodes )
110
+ self ._directfs_nodes .append (DirectFsAccessNode (dfsa , arg_node ))
111
+ return
149
112
150
113
@property
151
114
def directfs_nodes (self ):
0 commit comments