1
1
import pytest
2
- from astroid import Assign , Attribute , Call , Const , Expr , Module , Name # type: ignore
2
+ import astroid # type: ignore
3
+ from astroid import Assign , AssignName , Attribute , Call , Const , Expr , Module , Name # type: ignore
3
4
4
5
from databricks .labs .ucx .source_code .python .python_ast import Tree , TreeHelper
5
6
from databricks .labs .ucx .source_code .python .python_infer import InferredValue
@@ -139,21 +140,27 @@ def test_ignores_magic_marker_in_multiline_comment() -> None:
139
140
assert True
140
141
141
142
142
- def test_appends_statements () -> None :
143
- source_1 = "a = 'John'"
144
- maybe_tree_1 = Tree .maybe_normalized_parse (source_1 )
145
- assert maybe_tree_1 .tree is not None , maybe_tree_1 .failure
146
- tree_1 = maybe_tree_1 .tree
147
- source_2 = 'b = f"Hello {a}!"'
148
- maybe_tree_2 = Tree .maybe_normalized_parse (source_2 )
149
- assert maybe_tree_2 .tree is not None , maybe_tree_2 .failure
150
- tree_2 = maybe_tree_2 .tree
151
- tree_3 = tree_1 .append_tree (tree_2 )
152
- nodes = tree_3 .locate (Assign , [])
153
- tree = Tree (nodes [0 ].value ) # tree_3 only contains tree_2 statements
154
- values = list (InferredValue .infer_from_node (tree .node ))
155
- strings = list (value .as_string () for value in values )
156
- assert strings == ["Hello John!" ]
143
+ def test_tree_attach_child_tree_infers_value () -> None :
144
+ """Attaching trees allows traversing from both parent and child."""
145
+ inferred_string = "Hello John!"
146
+ parent_source , child_source = "a = 'John'" , 'b = f"Hello {a}!"'
147
+ parent_maybe_tree = Tree .maybe_normalized_parse (parent_source )
148
+ child_maybe_tree = Tree .maybe_normalized_parse (child_source )
149
+
150
+ assert parent_maybe_tree .tree is not None , parent_maybe_tree .failure
151
+ assert child_maybe_tree .tree is not None , child_maybe_tree .failure
152
+
153
+ parent_maybe_tree .tree .attach_child_tree (child_maybe_tree .tree )
154
+
155
+ nodes = parent_maybe_tree .tree .locate (Assign , [])
156
+ tree = Tree (nodes [1 ].value ) # Starting from the parent, we are looking for the last assign
157
+ strings = [value .as_string () for value in InferredValue .infer_from_node (tree .node )]
158
+ assert strings == [inferred_string ]
159
+
160
+ nodes = child_maybe_tree .tree .locate (Assign , [])
161
+ tree = Tree (nodes [0 ].value ) # Starting from child, we are looking for the first assign
162
+ strings = [value .as_string () for value in InferredValue .infer_from_node (tree .node )]
163
+ assert strings == [inferred_string ]
157
164
158
165
159
166
def test_is_from_module () -> None :
@@ -194,28 +201,23 @@ def test_is_instance_of(source, name, class_name) -> None:
194
201
assert Tree (var [0 ]).is_instance_of (class_name )
195
202
196
203
197
- def test_supports_recursive_refs_when_checking_module () -> None :
198
- source_1 = """
199
- df = spark.read.csv("hi")
200
- """
201
- source_2 = """
202
- df = df.withColumn(stuff)
203
- """
204
- source_3 = """
205
- df = df.withColumn(stuff2)
206
- """
207
- maybe_tree = Tree .maybe_normalized_parse (source_1 )
208
- assert maybe_tree .tree is not None , maybe_tree .failure
209
- main_tree = maybe_tree .tree
210
- maybe_tree_2 = Tree .maybe_normalized_parse (source_2 )
211
- assert maybe_tree_2 .tree is not None , maybe_tree_2 .failure
212
- tree_2 = maybe_tree_2 .tree
213
- main_tree .append_tree (tree_2 )
214
- maybe_tree_3 = Tree .maybe_normalized_parse (source_3 )
215
- assert maybe_tree_3 .tree is not None , maybe_tree_3 .failure
216
- tree_3 = maybe_tree_3 .tree
217
- main_tree .append_tree (tree_3 )
218
- assign = tree_3 .locate (Assign , [])[0 ]
204
+ def test_tree_attach_child_tree_propagates_module_reference () -> None :
205
+ """The spark module should propagate from the parent tree."""
206
+ source_1 = "df = spark.read.csv('hi')"
207
+ source_2 = "df = df.withColumn(stuff)"
208
+ source_3 = "df = df.withColumn(stuff2)"
209
+ first_line_maybe_tree = Tree .maybe_normalized_parse (source_1 )
210
+ second_line_maybe_tree = Tree .maybe_normalized_parse (source_2 )
211
+ third_line_maybe_tree = Tree .maybe_normalized_parse (source_3 )
212
+
213
+ assert first_line_maybe_tree .tree , first_line_maybe_tree .failure
214
+ assert second_line_maybe_tree .tree , second_line_maybe_tree .failure
215
+ assert third_line_maybe_tree .tree , third_line_maybe_tree .failure
216
+
217
+ first_line_maybe_tree .tree .attach_child_tree (second_line_maybe_tree .tree )
218
+ first_line_maybe_tree .tree .attach_child_tree (third_line_maybe_tree .tree )
219
+
220
+ assign = third_line_maybe_tree .tree .locate (Assign , [])[0 ]
219
221
assert Tree (assign .value ).is_from_module ("spark" )
220
222
221
223
@@ -302,6 +304,53 @@ def test_is_builtin(source, name, is_builtin) -> None:
302
304
assert False # could not locate call
303
305
304
306
307
+ def test_tree_attach_nodes_sets_parent () -> None :
308
+ node = astroid .extract_node ("b = a + 2" )
309
+ maybe_tree = Tree .maybe_normalized_parse ("a = 1" )
310
+ assert maybe_tree .tree , maybe_tree .failure
311
+
312
+ maybe_tree .tree .attach_nodes ([node ])
313
+
314
+ assert node .parent == maybe_tree .tree .node
315
+
316
+
317
+ def test_tree_attach_nodes_adds_node_to_body () -> None :
318
+ node = astroid .extract_node ("b = a + 2" )
319
+ maybe_tree = Tree .maybe_normalized_parse ("a = 1" )
320
+ assert maybe_tree .tree , maybe_tree .failure
321
+
322
+ maybe_tree .tree .attach_nodes ([node ])
323
+
324
+ assert maybe_tree .tree .node .body [- 1 ] == node
325
+
326
+
327
+ def test_tree_extend_globals_adds_assign_name_to_tree () -> None :
328
+ maybe_tree = Tree .maybe_normalized_parse ("a = 1" )
329
+ assert maybe_tree .tree , maybe_tree .failure
330
+
331
+ node = astroid .extract_node ("b = a + 2" )
332
+ assign_name = next (node .get_children ())
333
+ assert isinstance (assign_name , AssignName )
334
+
335
+ maybe_tree .tree .extend_globals ({"b" : [assign_name ]})
336
+
337
+ assert isinstance (maybe_tree .tree .node , Module )
338
+ assert maybe_tree .tree .node .globals .get ("b" ) == [assign_name ]
339
+
340
+
341
+ def test_tree_attach_child_tree_appends_globals_to_parent_tree () -> None :
342
+ parent_tree = Tree .maybe_normalized_parse ("a = 1" )
343
+ child_tree = Tree .maybe_normalized_parse ("b = a + 2" )
344
+
345
+ assert parent_tree .tree , parent_tree .failure
346
+ assert child_tree .tree , child_tree .failure
347
+
348
+ parent_tree .tree .attach_child_tree (child_tree .tree )
349
+
350
+ assert set (parent_tree .tree .node .globals .keys ()) == {"a" , "b" }
351
+ assert set (child_tree .tree .node .globals .keys ()) == {"b" }
352
+
353
+
305
354
def test_first_statement_is_none () -> None :
306
355
node = Const ("xyz" )
307
356
assert not Tree (node ).first_statement ()
@@ -311,14 +360,19 @@ def test_repr_is_truncated() -> None:
311
360
assert len (repr (Tree (Const ("xyz" )))) <= (32 + len ("..." ) + len ("<Tree: >" ))
312
361
313
362
314
- def test_append_tree_fails () -> None :
315
- with pytest .raises (NotImplementedError ):
316
- Tree (Const ("xyz" )).append_tree (Tree (Const ("xyz" )))
363
+ def test_tree_attach_child_tree_raises_not_implemented_error_for_constant_node () -> None :
364
+ with pytest .raises (NotImplementedError , match = "Cannot attach child tree: .*" ):
365
+ Tree (Const ("xyz" )).attach_child_tree (Tree (Const ("xyz" )))
317
366
318
367
319
- def test_append_node_fails () -> None :
320
- with pytest .raises (NotImplementedError ):
321
- Tree (Const ("xyz" )).append_nodes ([])
368
+ def test_tree_attach_nodes_raises_not_implemented_error_for_constant_node () -> None :
369
+ with pytest .raises (NotImplementedError , match = "Cannot attach nodes to: .*" ):
370
+ Tree (Const ("xyz" )).attach_nodes ([])
371
+
372
+
373
+ def test_extend_globals_raises_not_implemented_error_for_constant_node () -> None :
374
+ with pytest .raises (NotImplementedError , match = "Cannot extend globals to: .*" ):
375
+ Tree (Const ("xyz" )).extend_globals ({})
322
376
323
377
324
378
def test_nodes_between_fails () -> None :
@@ -330,11 +384,6 @@ def test_has_global_fails() -> None:
330
384
assert not Tree .new_module ().has_global ("xyz" )
331
385
332
386
333
- def test_append_globals_fails () -> None :
334
- with pytest .raises (NotImplementedError ):
335
- Tree (Const ("xyz" )).append_globals ({})
336
-
337
-
338
387
def test_globals_between_fails () -> None :
339
388
with pytest .raises (NotImplementedError ):
340
389
Tree (Const ("xyz" )).line_count ()
0 commit comments