Skip to content

Handle nonlocal and global #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 54 additions & 23 deletions auto_typing_final/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,21 @@ def texts_of_identifier_nodes(node: SgNode) -> Iterable[str]:
return (child.text() for child in node.children() if child.kind() == "identifier")


def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C901, PLR0912
def node_is_in_inner_function(root: SgNode, node: SgNode) -> bool:
for ancestor in node.ancestors():
if ancestor.kind() == "function_definition":
return ancestor != root
return False


def node_is_in_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
for ancestor in node.ancestors():
if ancestor.kind() in {"function_definition", "class_definition"}:
return ancestor != root
return False


def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C901, PLR0912, PLR0915
match node.kind():
case "assignment" | "augmented_assignment":
if left := node.field("left"):
Expand All @@ -23,9 +37,22 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C
yield from texts_of_identifier_nodes(left)
case "identifier":
yield left.text()
case "function_definition" | "class_definition" | "named_expression":
case "named_expression":
if name := node.field("name"):
yield name.text()
case "class_definition":
if name := node.field("name"):
yield name.text()
for function in node.find_all(kind="function_definition"):
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
if not node_is_in_inner_function(root=function, node=nonlocal_statement):
yield from texts_of_identifier_nodes(nonlocal_statement)
case "function_definition":
if name := node.field("name"):
yield name.text()
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
if not node_is_in_inner_function(root=node, node=nonlocal_statement):
yield from texts_of_identifier_nodes(nonlocal_statement)
case "import_from_statement":
match tuple((child.kind(), child) for child in node.children()):
case (("from", _), _, ("import", _), *name_nodes):
Expand Down Expand Up @@ -57,7 +84,7 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C
and (last_last_child := last_child_of_type(last_child, "identifier"))
):
yield last_last_child.text()
case "splat_pattern":
case "splat_pattern" | "global_statement" | "nonlocal_statement":
yield from texts_of_identifier_nodes(node)
case "dict_pattern":
for child in node.children():
Expand Down Expand Up @@ -109,16 +136,9 @@ def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[str]:
}


def node_is_in_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
for ancestor in node.ancestors():
if ancestor.kind() in {"function_definition", "class_definition"}:
return ancestor != root
return False


def find_definitions_in_scope_grouped_by_name(root: SgNode) -> Iterable[list[SgNode]]:
def find_definitions_in_scope_grouped_by_name(root: SgNode) -> dict[str, list[SgNode]]:
definition_map = defaultdict(list)
ignored_names = set[str]()

if parameters := root.field("parameters"):
for node in parameters.children():
for identifier in find_identifiers_in_function_parameter(node):
Expand All @@ -127,14 +147,25 @@ def find_definitions_in_scope_grouped_by_name(root: SgNode) -> Iterable[list[SgN
for node in root.find_all(rule):
if node_is_in_inner_function_or_class(root, node):
continue
match node.kind():
case "global_statement" | "nonlocal_statement":
ignored_names.update(texts_of_identifier_nodes(node))
case _:
for identifier in find_identifiers_in_function_body(node):
definition_map[identifier].append(node)

for param in ignored_names:
if param in definition_map:
del definition_map[param]
return definition_map.values()
for identifier in find_identifiers_in_function_body(node):
definition_map[identifier].append(node)

return definition_map


def find_definitions_in_global_scope(root: SgNode) -> dict[str, list[SgNode]]:
global_statement_identifiers = defaultdict(list)
for node in root.find_all(kind="global_statement"):
for identifier in texts_of_identifier_nodes(node):
global_statement_identifiers[identifier].append(node)

return {
identifier: (global_statement_identifiers[identifier] + definitions)
for identifier, definitions in find_definitions_in_scope_grouped_by_name(root).items()
}


def find_definitions_in_module(root: SgNode) -> Iterable[list[SgNode]]:
for function in root.find_all(kind="function_definition"):
yield from find_definitions_in_scope_grouped_by_name(function).values()
yield from find_definitions_in_global_scope(root).values()
16 changes: 6 additions & 10 deletions auto_typing_final/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ast_grep_py import Edit, SgNode, SgRoot

from auto_typing_final.finder import find_definitions_in_scope_grouped_by_name
from auto_typing_final.finder import find_definitions_in_module

TYPING_FINAL = "typing.Final"
TYPING_FINAL_ANNOTATION_REGEX = re.compile(r"typing\.Final\[(.*)\]{1}")
Expand Down Expand Up @@ -112,19 +112,15 @@ def make_edits_from_operation(operation: Operation) -> Iterable[Edit]: # noqa:
yield node.replace(f"{left}: {new_annotation[0]} = {right}")


def make_edits_for_all_assignments_in_scope(node: SgNode) -> Iterable[Edit]:
for assignments in find_definitions_in_scope_grouped_by_name(node):
yield from make_edits_from_operation(make_operation_from_assignments_to_one_name(assignments))


def make_edits_for_all_functions(root: SgNode) -> Iterable[Edit]:
for function in root.find_all(kind="function_definition"):
yield from make_edits_for_all_assignments_in_scope(function)
def make_edits_for_definitions(definitions: Iterable[list[SgNode]]) -> Iterable[Edit]:
for current_definitions in definitions:
yield from make_edits_from_operation(make_operation_from_assignments_to_one_name(current_definitions))


def run_fixer(source: str) -> str:
root = SgRoot(source, "python").root()
return root.commit_edits(list(make_edits_for_all_functions(root)))
edits = list(make_edits_for_definitions(find_definitions_in_module(root)))
return root.commit_edits(edits)


def main() -> None: # pragma: no cover
Expand Down
58 changes: 51 additions & 7 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from ast_grep_py import SgRoot

from auto_typing_final.main import make_edits_for_all_assignments_in_scope, run_fixer
from auto_typing_final.finder import find_definitions_in_scope_grouped_by_name
from auto_typing_final.main import make_edits_for_definitions, run_fixer


@pytest.mark.parametrize(
Expand Down Expand Up @@ -38,15 +39,18 @@
)
def test_variants(before: str, after: str) -> None:
root = SgRoot(before.strip(), "python").root()
assert root.commit_edits(list(make_edits_for_all_assignments_in_scope(root))) == after.strip()
assert (
root.commit_edits(list(make_edits_for_definitions(find_definitions_in_scope_grouped_by_name(root).values())))
== after.strip()
)


# fmt: off
scopes_cases = [
("""
a = 1
""", """
a = 1
a: typing.Final = 1
"""),

("""
Expand All @@ -66,7 +70,7 @@ def foo():
def bar():
a = 3
""", """
a = 1
a: typing.Final = 1

def foo():
a: typing.Final = 2
Expand All @@ -76,7 +80,7 @@ def bar():
"""),

("""
a = 1
a: typing.Final = 1

def foo():
global a
Expand Down Expand Up @@ -485,7 +489,7 @@ def foo():
nonlocal a
""", """
def foo():
a: typing.Final = 1
a = 1
nonlocal a
"""),

Expand Down Expand Up @@ -515,7 +519,7 @@ def foo():
global a
""", """
def foo():
a: typing.Final = 1
a = 1
global a
"""),

Expand All @@ -539,6 +543,46 @@ def foo():
global b
"""),

("""
def foo():
a: typing.Final = 1
b: typing.Final = 2
c: typing.Final = 3

def bar():
nonlocal a
b: typing.Final = 4
c: typing.Final = 5

class C:
a = 6
c = 7

def baz():
nonlocal a, b
b: typing.Final = 8
c: typing.Final = 9
""", """
def foo():
a = 1
b: typing.Final = 2
c: typing.Final = 3

def bar():
nonlocal a
b = 4
c: typing.Final = 5

class C:
a = 6
c = 7

def baz():
nonlocal a, b
b = 8
c: typing.Final = 9
"""),

("""
def foo():
foo: typing.Final = 1
Expand Down