Skip to content

Commit 8b3181f

Browse files
authored
feat: add find_variables ast-helper (#212)
* feat: find_variables * fix has_returns * add test has_returns
1 parent fb25066 commit 8b3181f

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

python/py_helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def has_returns(self, returns_str):
118118
return returns_str == self.tree.returns.id
119119
elif isinstance(self.tree.returns, ast.Constant):
120120
return returns_str == self.tree.returns.value
121+
elif isinstance((ann := self.tree.returns), ast.Subscript):
122+
return Node(ann).is_equivalent(returns_str)
121123
return False
122124

123125
def find_body(self):
@@ -251,6 +253,25 @@ def find_variable(self, name):
251253
return Node(node)
252254
return Node()
253255

256+
def find_variables(self, name):
257+
assignments = self._find_all((ast.Assign, ast.AnnAssign))
258+
var_list = []
259+
for node in assignments:
260+
if isinstance(node.tree, ast.Assign):
261+
for target in node.tree.targets:
262+
if isinstance(target, ast.Name):
263+
if target.id == name:
264+
var_list.append(node)
265+
if isinstance(target, ast.Attribute):
266+
names = name.split(".")
267+
if target.value.id == names[0] and target.attr == names[1]:
268+
var_list.append(node)
269+
elif isinstance(node.tree, ast.AnnAssign):
270+
if isinstance(node.tree.target, ast.Name):
271+
if node.tree.target.id == name:
272+
var_list.append(node)
273+
return var_list
274+
254275
# find variable incremented or decremented using += or -=
255276
def find_aug_variable(self, name):
256277
if not self._has_body():

python/py_helpers.test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,23 @@ def foo():
161161
)
162162
self.assertEqual(node.find_function("foo").find_aug_variable("x"), Node())
163163

164+
def test_find_variables(self):
165+
code_str = """
166+
x: int = 0
167+
a.b = 0
168+
x = 5
169+
a.b = 2
170+
x = 10
171+
"""
172+
node = Node(code_str)
173+
self.assertEqual(len(node.find_variables("x")), 3)
174+
self.assertTrue(node.find_variables("x")[0].is_equivalent("x: int = 0"))
175+
self.assertTrue(node.find_variables("x")[1].is_equivalent("x = 5"))
176+
self.assertTrue(node.find_variables("x")[2].is_equivalent("x = 10"))
177+
self.assertEqual(len(node.find_variables("a.b")), 2)
178+
self.assertTrue(node.find_variables("a.b")[0].is_equivalent("a.b = 0"))
179+
self.assertTrue(node.find_variables("a.b")[1].is_equivalent("a.b = 2"))
180+
164181

165182
class TestFunctionAndClassHelpers(unittest.TestCase):
166183
def test_find_function_returns_node(self):
@@ -295,12 +312,16 @@ def foo(a: int, b: int) -> int:
295312
def test_has_returns(self):
296313
code_str = """
297314
def foo() -> int:
298-
pass
315+
pass
316+
317+
def spam() -> Dict[str, int]:
318+
pass
299319
"""
300320
node = Node(code_str)
301321

302322
self.assertTrue(node.find_function("foo").has_returns("int"))
303323
self.assertFalse(node.find_function("foo").has_returns("str"))
324+
self.assertTrue(node.find_function("spam").has_returns("Dict[str, int]"))
304325

305326
def test_has_returns_without_returns(self):
306327
code_str = """

0 commit comments

Comments
 (0)