Skip to content

Refactor a bit #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 80 additions & 78 deletions auto_typing_final/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,160 +4,162 @@
from ast_grep_py import Config, SgNode

# https://github.com/tree-sitter/tree-sitter-python/blob/71778c2a472ed00a64abf4219544edbf8e4b86d7/grammar.js
# ast-grep rule matching every node kind that the finder inspects for name
# definitions: one ``{"kind": ...}`` alternative per kind under ``"any"``.
DEFINITION_RULE: Config = {
    "rule": {
        "any": [
            {"kind": definition_kind}
            for definition_kind in (
                "assignment",
                "augmented_assignment",
                "named_expression",
                "function_definition",
                "global_statement",
                "nonlocal_statement",
                "class_definition",
                "import_from_statement",
                "as_pattern",
                "keyword_pattern",
                "splat_pattern",
                "dict_pattern",
                "list_pattern",
                "tuple_pattern",
                "for_statement",
            )
        ]
    }
}


def last_child_of_type(node: SgNode, type_: str) -> SgNode | None:
    """Return the last direct child of ``node`` when its kind equals ``type_``.

    Returns ``None`` when ``node`` has no children or when the last child is
    of a different kind.
    """
    children = node.children()
    if not children:
        return None
    candidate = children[-1]
    return candidate if candidate.kind() == type_ else None


def texts_of_identifier_nodes(node: SgNode) -> Iterable[str]:
    """Return a lazy iterable of the source text of ``node``'s direct ``identifier`` children."""
    direct_children = node.children()
    return (candidate.text() for candidate in direct_children if candidate.kind() == "identifier")
def find_identifiers_in_children(node: SgNode) -> Iterable[SgNode]:
    """Yield every direct child of ``node`` whose kind is ``identifier``."""
    yield from (candidate for candidate in node.children() if candidate.kind() == "identifier")


def is_inside_inner_function(root: SgNode, node: SgNode) -> bool:
    """Tell whether ``node`` belongs to a function other than ``root``.

    Walks ``node.ancestors()`` and stops at the first ``function_definition``
    it meets.

    Args:
        root: The function node whose own scope is being analysed.
        node: The node to classify.

    Returns:
        ``True`` when the first ``function_definition`` ancestor is not
        ``root``; ``False`` when it is ``root`` or when no ancestor is a
        function definition.
    """
    for ancestor in node.ancestors():
        if ancestor.kind() == "function_definition":
            return ancestor != root
    return False


# Pre-rename spelling, kept as an alias so call sites that still use the old
# name resolve to the same implementation.
node_is_in_inner_function = is_inside_inner_function


def is_inside_inner_function_or_class(root: SgNode, node: SgNode) -> bool:
    """Tell whether ``node`` belongs to a function or class other than ``root``.

    Walks ``node.ancestors()`` and stops at the first ``function_definition``
    or ``class_definition`` it meets.

    Args:
        root: The scope node being analysed.
        node: The node to classify.

    Returns:
        ``True`` when the first function/class ancestor is not ``root``;
        ``False`` when it is ``root`` or when no ancestor is a function or
        class definition.
    """
    for ancestor in node.ancestors():
        if ancestor.kind() in {"function_definition", "class_definition"}:
            return ancestor != root
    return False


# Pre-rename spelling, kept as an alias so call sites that still use the old
# name resolve to the same implementation.
node_is_in_inner_function_or_class = is_inside_inner_function_or_class


def find_identifiers_in_function_body(node: SgNode) -> Iterable[str]: # noqa: C901, PLR0912, PLR0915
def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[SgNode]:
match node.kind():
case "default_parameter" | "typed_default_parameter":
if name := node.field("name"):
yield name
case "identifier":
yield node
case _:
yield from find_identifiers_in_children(node)


def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa: C901, PLR0912, PLR0915
match node.kind():
case "assignment" | "augmented_assignment":
if left := node.field("left"):
match left.kind():
case "pattern_list" | "tuple_pattern":
yield from texts_of_identifier_nodes(left)
case "identifier":
yield left.text()
if not (left := node.field("left")):
return
match left.kind():
case "pattern_list" | "tuple_pattern":
yield from find_identifiers_in_children(left)
case "identifier":
yield left
case "named_expression":
if name := node.field("name"):
yield name.text()
yield name
case "class_definition":
if name := node.field("name"):
yield name.text()
yield name
for function in node.find_all(kind="function_definition"):
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
if not node_is_in_inner_function(root=function, node=nonlocal_statement):
yield from texts_of_identifier_nodes(nonlocal_statement)
if is_inside_inner_function(root=function, node=nonlocal_statement):
continue
yield from find_identifiers_in_children(nonlocal_statement)
case "function_definition":
if name := node.field("name"):
yield name.text()
yield name
for nonlocal_statement in node.find_all(kind="nonlocal_statement"):
if not node_is_in_inner_function(root=node, node=nonlocal_statement):
yield from texts_of_identifier_nodes(nonlocal_statement)
if is_inside_inner_function(root=node, node=nonlocal_statement):
continue
yield from find_identifiers_in_children(nonlocal_statement)
case "import_from_statement":
match tuple((child.kind(), child) for child in node.children()):
case (("from", _), _, ("import", _), *name_nodes):
for _, child in name_nodes:
match child.kind():
for name_node_kind, name_node in name_nodes:
match name_node_kind:
case "dotted_name":
if identifier := last_child_of_type(child, "identifier"):
yield identifier.text()
if identifier := last_child_of_type(name_node, "identifier"):
yield identifier
case "aliased_import":
if alias := child.field("alias"):
yield alias.text()
if alias := name_node.field("alias"):
yield alias
case "as_pattern":
match tuple((child.kind(), child) for child in node.children()):
case (
(("identifier", _), ("as", _), ("as_pattern_target", alias))
| (("case_pattern", _), ("as", _), ("identifier", alias))
):
yield alias.text()
yield alias
case "keyword_pattern":
match tuple((child.kind(), child) for child in node.children()):
case (("identifier", _), ("=", _), ("dotted_name", alias)):
if identifier := last_child_of_type(alias, "identifier"):
yield identifier.text()
yield identifier
case "list_pattern" | "tuple_pattern":
for child in node.children():
if (
child.kind() == "case_pattern"
and (last_child := last_child_of_type(child, "dotted_name"))
and (last_last_child := last_child_of_type(last_child, "identifier"))
and (dotted_name := last_child_of_type(child, "dotted_name"))
and (identifier := last_child_of_type(dotted_name, "identifier"))
):
yield last_last_child.text()
yield identifier
case "splat_pattern" | "global_statement" | "nonlocal_statement":
yield from texts_of_identifier_nodes(node)
yield from find_identifiers_in_children(node)
case "dict_pattern":
for child in node.children():
if (
child.kind() == "case_pattern"
and (previous_child := child.prev())
and previous_child.kind() == ":"
and (last_child := last_child_of_type(child, "dotted_name"))
and (last_last_child := last_child_of_type(last_child, "identifier"))
and (dotted_name := last_child_of_type(child, "dotted_name"))
and (identifier := last_child_of_type(dotted_name, "identifier"))
):
yield last_last_child.text()
yield identifier
case "for_statement":
if left := node.field("left"):
for child in left.find_all(kind="identifier"):
yield child.text()


def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[str]:
match node.kind():
case "default_parameter" | "typed_default_parameter":
if name := node.field("name"):
yield name.text()
case "identifier":
yield node.text()
case _:
yield from texts_of_identifier_nodes(node)


rule: Config = {
"rule": {
"any": [
{"kind": "assignment"},
{"kind": "augmented_assignment"},
{"kind": "named_expression"},
{"kind": "function_definition"},
{"kind": "global_statement"},
{"kind": "nonlocal_statement"},
{"kind": "class_definition"},
{"kind": "import_from_statement"},
{"kind": "as_pattern"},
{"kind": "keyword_pattern"},
{"kind": "splat_pattern"},
{"kind": "dict_pattern"},
{"kind": "list_pattern"},
{"kind": "tuple_pattern"},
{"kind": "for_statement"},
]
}
}
yield from left.find_all(kind="identifier")


def find_definitions_in_scope_grouped_by_name(root: SgNode) -> dict[str, list[SgNode]]:
    """Group the nodes that define names in ``root``'s own scope by name.

    Args:
        root: The scope node to analyse. If it has a ``parameters`` field,
            those parameters are included as definitions.

    Returns:
        A mapping from identifier text to the defining nodes, in discovery
        order: parameter nodes first, then nodes matched by
        ``DEFINITION_RULE``. Nodes inside an inner function or class, and
        ``root`` itself, are skipped. The mapping is a ``defaultdict``, so
        indexing a missing name inserts an empty list.
    """
    definition_map: defaultdict[str, list[SgNode]] = defaultdict(list)

    if parameters := root.field("parameters"):
        for parameter in parameters.children():
            for identifier in find_identifiers_in_function_parameter(parameter):
                definition_map[identifier.text()].append(parameter)

    for node in root.find_all(DEFINITION_RULE):
        if is_inside_inner_function_or_class(root, node) or node == root:
            continue
        for identifier in find_identifiers_in_function_body(node):
            definition_map[identifier.text()].append(node)

    return definition_map


def find_definitions_in_global_scope(root: SgNode) -> dict[str, list[SgNode]]:
global_statement_identifiers = defaultdict(list)
for node in root.find_all(kind="global_statement"):
for identifier in texts_of_identifier_nodes(node):
global_statement_identifiers[identifier].append(node)
for global_statement in root.find_all(kind="global_statement"):
for identifier in find_identifiers_in_children(global_statement):
global_statement_identifiers[identifier.text()].append(global_statement)

return {
identifier: (global_statement_identifiers[identifier] + definitions)
Expand Down
123 changes: 2 additions & 121 deletions auto_typing_final/main.py
Original file line number Diff line number Diff line change
@@ -1,126 +1,7 @@
import argparse
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TextIO, cast

from ast_grep_py import Edit, SgNode, SgRoot

from auto_typing_final.finder import find_definitions_in_module

# Annotation text that the fixer adds to, or strips from, assignments.
TYPING_FINAL = "typing.Final"
# Captures whatever sits between the brackets of ``typing.Final[...]``; used to
# recover the inner annotation when the ``Final`` wrapper is removed.
TYPING_FINAL_ANNOTATION_REGEX = re.compile(r"typing\.Final\[(.*)\]{1}")


@dataclass
class AssignmentWithoutAnnotation:
    """An assignment whose children are ``identifier``, ``=`` and a value.

    Attributes:
        node: The assignment node itself.
        left: Source text of the target identifier.
        right: Source text of the assigned value.
    """

    node: SgNode
    left: str
    right: str


@dataclass
class AssignmentWithAnnotation:
    """An assignment whose children are ``identifier``, ``:``, ``type``, ``=`` and a value.

    Attributes:
        node: The assignment node itself.
        left: Source text of the target identifier.
        annotation: Source text of the type annotation.
        right: Source text of the assigned value.
    """

    node: SgNode
    left: str
    annotation: str
    right: str


@dataclass
class OtherDefinition:
    """Any defining node that is not one of the two simple assignment shapes.

    Attributes:
        node: The defining node.
    """

    node: SgNode


# Any classified definition of a name: one of the two assignment shapes, or
# the catch-all for every other defining node.
Definition = AssignmentWithoutAnnotation | AssignmentWithAnnotation | OtherDefinition


@dataclass
class AddFinal:
    """Operation requesting that ``typing.Final`` be added to a single definition.

    Attributes:
        definition: The one definition of the name.
    """

    definition: Definition


@dataclass
class RemoveFinal:
    """Operation requesting that ``typing.Final`` be removed from a name's definitions.

    Attributes:
        nodes: Every classified definition of the name.
    """

    nodes: list[Definition]


# The action chosen for one name: add ``typing.Final`` or remove it.
Operation = AddFinal | RemoveFinal


def is_in_loop(node: SgNode) -> bool:
    """Return ``True`` when any ancestor of ``node`` is a ``for`` or ``while`` statement."""
    loop_kinds = {"for_statement", "while_statement"}
    for ancestor in node.ancestors():
        if ancestor.kind() in loop_kinds:
            return True
    return False


def make_operation_from_assignments_to_one_name(nodes: list[SgNode]) -> Operation:
    """Classify every definition of one name and choose the operation to apply.

    Args:
        nodes: All nodes that define the same name.

    Returns:
        ``AddFinal`` when there is exactly one definition and none of the
        nodes is inside a loop; otherwise ``RemoveFinal`` carrying every
        classified definition.
    """
    definitions: list[Definition] = []
    seen_in_loop = False

    for node in nodes:
        children = node.children()

        if is_in_loop(node):
            seen_in_loop = True

        if node.kind() != "assignment":
            definitions.append(OtherDefinition(node))
            continue

        child_kinds = [child.kind() for child in children]
        if len(child_kinds) == 3 and child_kinds[:2] == ["identifier", "="]:
            # Shape: ``name = value``.
            definitions.append(
                AssignmentWithoutAnnotation(node=node, left=children[0].text(), right=children[2].text())
            )
        elif len(child_kinds) == 5 and child_kinds[:4] == ["identifier", ":", "type", "="]:
            # Shape: ``name: annotation = value``.
            definitions.append(
                AssignmentWithAnnotation(
                    node=node, left=children[0].text(), annotation=children[2].text(), right=children[4].text()
                )
            )
        else:
            definitions.append(OtherDefinition(node))

    if seen_in_loop or len(definitions) != 1:
        return RemoveFinal(definitions)
    return AddFinal(definitions[0])


def make_edits_from_operation(operation: Operation) -> Iterable[Edit]:  # noqa: C901
    """Yield the source edits that carry out ``operation``.

    For ``AddFinal``, an unannotated assignment gains ``: typing.Final`` and an
    annotated one has its annotation wrapped in ``typing.Final[...]`` unless
    the annotation already mentions ``typing.Final``. For ``RemoveFinal``,
    each assignment is rewritten without the ``Final`` wrapper. Definitions
    that are not simple assignments produce no edits.
    """
    if isinstance(operation, AddFinal):
        target = operation.definition
        if isinstance(target, AssignmentWithoutAnnotation):
            node, left, right = target.node, target.left, target.right
            yield node.replace(f"{left}: {TYPING_FINAL} = {right}")
        elif isinstance(target, AssignmentWithAnnotation):
            node, left, annotation, right = target.node, target.left, target.annotation, target.right
            if TYPING_FINAL not in annotation:
                yield node.replace(f"{left}: {TYPING_FINAL}[{annotation}] = {right}")
        return

    if not isinstance(operation, RemoveFinal):
        return

    for target in operation.nodes:
        if isinstance(target, AssignmentWithoutAnnotation):
            node, left, right = target.node, target.left, target.right
            yield node.replace(f"{left} = {right}")
        elif isinstance(target, AssignmentWithAnnotation):
            node, left, annotation, right = target.node, target.left, target.annotation, target.right
            if annotation == TYPING_FINAL:
                # Bare ``typing.Final``: drop the annotation entirely.
                yield node.replace(f"{left} = {right}")
            elif new_annotation := TYPING_FINAL_ANNOTATION_REGEX.findall(annotation):
                # ``typing.Final[X]``: keep only the inner annotation ``X``.
                yield node.replace(f"{left}: {new_annotation[0]} = {right}")


def make_edits_for_definitions(definitions: Iterable[list[SgNode]]) -> Iterable[Edit]:
    """Yield the edits for every group of same-name definitions in ``definitions``."""
    for nodes_for_one_name in definitions:
        operation = make_operation_from_assignments_to_one_name(nodes_for_one_name)
        yield from make_edits_from_operation(operation)


def run_fixer(source: str) -> str:
    """Parse ``source`` as Python, compute the ``typing.Final`` edits and return the edited text."""
    root = SgRoot(source, "python").root()
    grouped_definitions = find_definitions_in_module(root)
    edits = list(make_edits_for_definitions(grouped_definitions))
    return root.commit_edits(edits)
from auto_typing_final.transform import transform_file_content


def main() -> None: # pragma: no cover
Expand All @@ -130,7 +11,7 @@ def main() -> None: # pragma: no cover
for file in cast(list[TextIO], parser.parse_args().files):
data = file.read()
file.seek(0)
file.write(run_fixer(data))
file.write(transform_file_content(data))
file.truncate()


Expand Down
Loading