Skip to content

Make it work on real code :) #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 25, 2024
2 changes: 1 addition & 1 deletion Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ publish:
uv tool run twine upload dist/* --username __token__ --password $PYPI_TOKEN

run *args:
@uv run add-typing-final {{ args }}
@.venv/bin/add-typing-final {{ args }}
128 changes: 128 additions & 0 deletions add_typing_final/finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from collections import defaultdict
from collections.abc import Iterable

from ast_grep_py import Config, SgNode


def last_child_of_type(node: SgNode, type_: str) -> SgNode | None:
return last_child if (children := node.children()) and (last_child := children[-1]).kind() == type_ else None


def texts_of_identifier_nodes(node: SgNode) -> Iterable[str]:
return (child.text() for child in node.children() if child.kind() == "identifier")


def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C901, PLR0912
match node.kind():
case "assignment" | "augmented_assignment":
if (left := node.field("left")):
match left.kind():
case "pattern_list" | "tuple_pattern":
yield from texts_of_identifier_nodes(left)
case "identifier":
yield left.text()
case "function_definition" | "class_definition" | "named_expression":
if name := node.field("name"):
yield name.text()
case "import_from_statement":
match tuple((child.kind(), child) for child in node.children()):
case (("from", _), _, ("import", _), *name_nodes):
for _, child in name_nodes:
match child.kind():
case "dotted_name":
if identifier := last_child_of_type(child, "identifier"):
yield identifier.text()
case "aliased_import":
if alias := child.field("alias"):
yield alias.text()
case "as_pattern":
match tuple((child.kind(), child) for child in node.children()):
case (
(("identifier", _), ("as", _), ("as_pattern_target", alias))
| (("case_pattern", _), ("as", _), ("identifier", alias))
):
yield alias.text()
case "keyword_pattern":
match tuple((child.kind(), child) for child in node.children()):
case (("identifier", _), ("=", _), ("dotted_name", alias)):
if identifier := last_child_of_type(alias, "identifier"):
yield identifier.text()
case "splat_pattern":
yield from texts_of_identifier_nodes(node)
case "dict_pattern":
for child in node.children():
if (
child.kind() == "case_pattern"
and (previous_child := child.prev())
and previous_child.kind() == ":"
and (last_child := last_child_of_type(child, "dotted_name"))
and (last_last_child := last_child_of_type(last_child, "identifier"))
):
yield last_last_child.text()
case "for_statement":
if left := node.field("left"):
for child in left.find_all(kind="identifier"):
yield child.text()


def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[str]:
match node.kind():
case "default_parameter" | "typed_default_parameter":
if name := node.field("name"):
yield name.text()
case "identifier":
yield node.text()
case _:
yield from texts_of_identifier_nodes(node)


rule: Config = {
"rule": {
"any": [
{"kind": "assignment"},
{"kind": "augmented_assignment"},
{"kind": "named_expression"},
{"kind": "function_definition"},
{"kind": "global_statement"},
{"kind": "nonlocal_statement"},
{"kind": "class_definition"},
{"kind": "import_from_statement"},
{"kind": "as_pattern"},
{"kind": "keyword_pattern"},
{"kind": "splat_pattern"},
{"kind": "dict_pattern"},
{"kind": "for_statement"},
]
}
}


def node_is_in_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
for ancestor in node.ancestors():
if ancestor.kind() in {"function_definition", "class_definition"}:
return ancestor != root
return False


def find_definitions_in_scope_grouped_by_name(root: SgNode) -> Iterable[list[SgNode]]:
definition_map = defaultdict(list)
ignored_names = set[str]()
if parameters := root.field("parameters"):
for node in parameters.children():
for identifier in find_identifiers_in_function_parameter(node):
definition_map[identifier].append(node)

for node in root.find_all(rule):
if node_is_in_inner_function_or_class(root, node):
continue
match node.kind():
case "global_statement" | "nonlocal_statement":
ignored_names.update(texts_of_identifier_nodes(node))
case _:
for identifier in find_identifiers_in_function_body(node):
definition_map[identifier].append(node)

for param in ignored_names:
if param in definition_map:
del definition_map[param]
return definition_map.values()
72 changes: 34 additions & 38 deletions add_typing_final/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import re
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TextIO, cast

from ast_grep_py import Edit, SgNode, SgRoot

from add_typing_final.finder import find_definitions_in_scope_grouped_by_name

# https://github.com/tree-sitter/tree-sitter-python/blob/71778c2a472ed00a64abf4219544edbf8e4b86d7/grammar.js


Expand All @@ -29,48 +30,61 @@ class AssignmentWithAnnotation:
right: str


Assignment = AssignmentWithoutAnnotation | AssignmentWithAnnotation
@dataclass
class OtherDefinition:
node: SgNode


Definition = AssignmentWithoutAnnotation | AssignmentWithAnnotation | OtherDefinition


@dataclass
class AddFinal:
assignment: Assignment
definition: Definition


@dataclass
class RemoveFinal:
nodes: list[Assignment]
nodes: list[Definition]


Operation = AddFinal | RemoveFinal


def is_in_loop(node: SgNode) -> bool:
return any(ancestor.kind() in {"for_statement", "while_statement"} for ancestor in node.ancestors())


def make_operation_from_assignments_to_one_name(nodes: list[SgNode]) -> Operation:
value_assignments: list[Assignment] = []
value_assignments: list[Definition] = []

for node in nodes:
children = node.children()

match tuple(child.kind() for child in children):
case ("identifier", "=", _):
value_assignments.append(
AssignmentWithoutAnnotation(node=node, left=children[0].text(), right=children[2].text())
)
case ("identifier", ":", "type", "=", _):
value_assignments.append(
AssignmentWithAnnotation(
node=node, left=children[0].text(), annotation=children[2].text(), right=children[4].text()
if node.kind() == "assignment" and not is_in_loop(node):
match tuple(child.kind() for child in children):
case ("identifier", "=", _):
value_assignments.append(
AssignmentWithoutAnnotation(node=node, left=children[0].text(), right=children[2].text())
)
)

case ("identifier", ":", "type", "=", _):
value_assignments.append(
AssignmentWithAnnotation(
node=node, left=children[0].text(), annotation=children[2].text(), right=children[4].text()
)
)
case _:
value_assignments.append(OtherDefinition(node))
else:
value_assignments.append(OtherDefinition(node))
match value_assignments:
case [assignment]:
return AddFinal(assignment)
case assignments:
return RemoveFinal(assignments)


def convert_edits_from_operation(operation: Operation) -> Iterable[Edit]: # noqa: C901
def make_edits_from_operation(operation: Operation) -> Iterable[Edit]: # noqa: C901
match operation:
case AddFinal(assignment):
match assignment:
Expand All @@ -93,32 +107,14 @@ def convert_edits_from_operation(operation: Operation) -> Iterable[Edit]: # noq
yield node.replace(f"{left}: {new_annotation[0]} = {right}")


def find_assignments_not_in_function_or_class(node: SgNode) -> Iterable[SgNode]:
if node.kind() == "assignment":
yield node
else:
for child in node.children():
if child.kind() not in {"function_definition", "class_definition"}:
yield from find_assignments_not_in_function_or_class(child)


def find_assignments_grouped_by_name(node: SgNode) -> Iterable[list[SgNode]]:
assignment_map: defaultdict[str, list[SgNode]] = defaultdict(list)
for child in find_assignments_not_in_function_or_class(node):
if left := child.field("left"):
assignment_map[left.text()].append(child)
return assignment_map.values()


def make_edits_for_all_assignments_in_scope(node: SgNode) -> Iterable[Edit]:
for assignments in find_assignments_grouped_by_name(node):
yield from convert_edits_from_operation(make_operation_from_assignments_to_one_name(assignments))
for assignments in find_definitions_in_scope_grouped_by_name(node):
yield from make_edits_from_operation(make_operation_from_assignments_to_one_name(assignments))


def make_edits_for_all_functions(root: SgNode) -> Iterable[Edit]:
for function in root.find_all(kind="function_definition"):
if body := function.field("body"):
yield from make_edits_for_all_assignments_in_scope(body)
yield from make_edits_for_all_assignments_in_scope(function)


def run_fixer(source: str) -> str:
Expand Down
18 changes: 14 additions & 4 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
("b = 1\nb = 2\na = 3", "b = 1\nb = 2\na: typing.Final = 3"),
("a = 1\nb = 2\nb = 3", "a: typing.Final = 1\nb = 2\nb = 3"),
("a = 1\na = 2\nb: int", "a = 1\na = 2\nb: int"),
("a = 1\na: int", "a: typing.Final = 1\na: int"),
("a: int\na = 1", "a: int\na: typing.Final = 1"),
("a: typing.Final\na = 1", "a: typing.Final\na: typing.Final = 1"),
("a: int\na: int = 1", "a: int\na: typing.Final[int] = 1"),
("a = 1\na: int", "a = 1\na: int"),
("a: int\na = 1", "a: int\na = 1"),
("a: typing.Final\na = 1", "a: typing.Final\na = 1"),
("a: int\na: int = 1", "a: int\na: int = 1"),
("a, b = 1, 2", "a, b = 1, 2"),
("(a, b) = 1, 2", "(a, b) = 1, 2"),
("(a, b) = t()", "(a, b) = t()"),
Expand All @@ -44,6 +44,7 @@ def test_variants(before: str, after: str) -> None:
scopes_case = (
"""
a = 1
b, c = 1
MUTABLE_FIRST = 1
MUTABLE_FIRST = 2

Expand Down Expand Up @@ -82,6 +83,10 @@ def duplicated(self) -> None:

def second() -> whatever:
hi = "hi"
for _ in ...:
me = 1
ih = 0
ih += 1

class C:
@t(a=1)
Expand Down Expand Up @@ -114,6 +119,7 @@ def fifth() -> None:
""",
"""
a = 1
b, c = 1
MUTABLE_FIRST = 1
MUTABLE_FIRST = 2

Expand Down Expand Up @@ -152,6 +158,10 @@ def duplicated(self) -> None:

def second() -> whatever:
hi: typing.Final = "hi"
for _ in ...:
me = 1
ih = 0
ih += 1

class C:
@t(a=1)
Expand Down