Skip to content

Commit 396e5c7

Browse files
committed
improve ast
1 parent fe47c9a commit 396e5c7

File tree

2 files changed

+75
-12
lines changed

2 files changed

+75
-12
lines changed

gdtoolkit/common/ast.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def __init__(self, node: Tree):
1919

2020
def _load_sub_statements(self):
2121
if self.kind == "class_def":
22-
pass
23-
elif self.kind == "property_body_def":
24-
pass
25-
elif self.kind in ["func_def", "static_func_def"]:
22+
raise NotImplementedError
23+
if self.kind == "property_body_def":
24+
raise NotImplementedError
25+
if self.kind in ["func_def", "static_func_def"]:
2626
self.sub_statements = [Statement(n) for n in self.lark_node.children[1:]]
2727
elif self.kind == "if_stmt":
2828
for branch in self.lark_node.children:
@@ -31,16 +31,22 @@ def _load_sub_statements(self):
3131
else:
3232
self.sub_statements += [Statement(n) for n in branch.children]
3333
elif self.kind == "while_stmt":
34-
pass
34+
self.sub_statements = [Statement(n) for n in self.lark_node.children[1:]]
3535
elif self.kind == "for_stmt":
36-
pass
36+
self.sub_statements = [Statement(n) for n in self.lark_node.children[2:]]
3737
elif self.kind == "match_stmt":
38-
pass
38+
for branch in self.lark_node.children:
39+
self.sub_statements += [Statement(n) for n in branch.children[1:]]
3940
for sub_statement in self.sub_statements:
4041
self.all_sub_statements += [
4142
sub_statement
4243
] + sub_statement.all_sub_statements
4344

45+
def __repr__(self):
46+
return "Statement({}:{}:{})".format(
47+
self.lark_node.data, self.lark_node.line, self.lark_node.column
48+
)
49+
4450

4551
# pylint: disable=too-few-public-methods
4652
class Parameter:
@@ -50,15 +56,15 @@ def __init__(self, node: Tree):
5056
self.name = node.children[0].value
5157

5258

53-
# TODO: inherit from statement
5459
# pylint: disable=too-few-public-methods
55-
class Function:
60+
class Function(Statement):
5661
"""Abstract representation of function"""
5762

5863
def __init__(self, func_def: Tree):
59-
self.lark_node = func_def
64+
super().__init__(func_def)
6065
self.name = ""
6166
self.parameters = [] # type: List[Parameter]
67+
self.all_statements = [self] + self.all_sub_statements # type: ignore
6268

6369
self._load_data_from_func_def(func_def)
6470

@@ -72,8 +78,6 @@ def _load_data_from_func_def(self, func_def: Tree) -> None:
7278
for c in func_args.children # type: ignore
7379
if c.data != "trailing_comma"
7480
]
75-
slf = Statement(self.lark_node)
76-
self.all_statements = [slf] + slf.all_sub_statements
7781

7882

7983
# pylint: disable=too-few-public-methods

tests/test_ast.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from gdtoolkit.parser import parser
2+
from gdtoolkit.common.ast import AbstractSyntaxTree
3+
4+
5+
def test_toplevel():
6+
code = """func foo():
7+
pass
8+
"""
9+
parse_tree = parser.parse(code, gather_metadata=True)
10+
ast = AbstractSyntaxTree(parse_tree)
11+
assert len(ast.all_functions) == 1
12+
function = ast.all_functions[0]
13+
assert function.name == "foo"
14+
assert len(function.all_sub_statements) == 1
15+
16+
17+
def test_all_sub_statements_of_function():
18+
code = """func foo():
19+
pass
20+
if true:
21+
pass
22+
if true:
23+
return
24+
elif true:
25+
return
26+
else:
27+
return
28+
while true:
29+
if true:
30+
return
31+
for f in range(10):
32+
if true:
33+
return
34+
var x
35+
match(x):
36+
1:
37+
pass
38+
"""
39+
parse_tree = parser.parse(code, gather_metadata=True)
40+
ast = AbstractSyntaxTree(parse_tree)
41+
assert len(ast.all_functions) == 1
42+
function = ast.all_functions[0]
43+
assert len(function.all_sub_statements) == 16
44+
assert function.all_sub_statements[0].kind == 'pass_stmt'
45+
assert function.all_sub_statements[1].kind == 'if_stmt'
46+
assert function.all_sub_statements[2].kind == 'pass_stmt'
47+
assert function.all_sub_statements[3].kind == 'if_stmt'
48+
assert function.all_sub_statements[4].kind == 'return_stmt'
49+
assert function.all_sub_statements[5].kind == 'return_stmt'
50+
assert function.all_sub_statements[6].kind == 'return_stmt'
51+
assert function.all_sub_statements[7].kind == 'while_stmt'
52+
assert function.all_sub_statements[8].kind == 'if_stmt'
53+
assert function.all_sub_statements[9].kind == 'return_stmt'
54+
assert function.all_sub_statements[10].kind == 'for_stmt'
55+
assert function.all_sub_statements[11].kind == 'if_stmt'
56+
assert function.all_sub_statements[12].kind == 'return_stmt'
57+
assert function.all_sub_statements[13].kind == 'func_var_stmt'
58+
assert function.all_sub_statements[14].kind == 'match_stmt'
59+
assert function.all_sub_statements[15].kind == 'pass_stmt'

0 commit comments

Comments
 (0)