Skip to content

Import typing if annotation has been added and it is not imported yet & add --check to CLI & return 1 if changes have been made #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions auto_typing_final/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,20 @@ def find_identifiers_in_function_parameter(node: SgNode) -> Iterable[SgNode]:
yield from find_identifiers_in_children(node)


def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa: C901, PLR0912, PLR0915
def find_identifiers_in_import(node: SgNode) -> Iterable[SgNode]:
match tuple((child.kind(), child) for child in node.children()):
case (("from", _), _, ("import", _), *name_nodes) | (("import", _), *name_nodes):
for kind, name_node in name_nodes:
match kind:
case "dotted_name":
if identifier := last_child_of_type(name_node, "identifier"):
yield identifier
case "aliased_import":
if alias := name_node.field("alias"):
yield alias


def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa: C901, PLR0912
match node.kind():
case "assignment" | "augmented_assignment":
if not (left := node.field("left")):
Expand Down Expand Up @@ -92,16 +105,7 @@ def find_identifiers_in_function_body(node: SgNode) -> Iterable[SgNode]: # noqa
continue
yield from find_identifiers_in_children(nonlocal_statement)
case "import_from_statement" | "import_statement":
match tuple((child.kind(), child) for child in node.children()):
case (("from", _), _, ("import", _), *name_nodes) | (("import", _), *name_nodes):
for kind, name_node in name_nodes:
match kind:
case "dotted_name":
if identifier := last_child_of_type(name_node, "identifier"):
yield identifier
case "aliased_import":
if alias := name_node.field("alias"):
yield alias
yield from find_identifiers_in_import(node)
case "as_pattern":
match tuple((child.kind(), child) for child in node.children()):
case (
Expand Down Expand Up @@ -172,3 +176,15 @@ def find_definitions_in_module(root: SgNode) -> Iterable[list[SgNode]]:
for function in root.find_all(kind="function_definition"):
yield from find_definitions_in_scope_grouped_by_name(function).values()
yield from find_definitions_in_global_scope(root).values()


def has_global_import_with_name(root: SgNode, name: str) -> bool:
for import_statement in root.find_all(
{"rule": {"any": [{"kind": "import_from_statement"}, {"kind": "import_statement"}]}}
):
if is_inside_inner_function_or_class(root, import_statement):
continue
for identifier in find_identifiers_in_import(import_statement):
if identifier.text() == name:
return True
return False
26 changes: 19 additions & 7 deletions auto_typing_final/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
import argparse
import sys
from difflib import ndiff
from typing import TextIO, cast

from auto_typing_final.transform import transform_file_content


def main() -> None: # pragma: no cover
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("files", type=argparse.FileType("r+"), nargs="*")
parser.add_argument("--check", action="store_true")
args = parser.parse_args()

for file in cast(list[TextIO], parser.parse_args().files):
has_changes = False

for file in cast(list[TextIO], args.files):
data = file.read()
file.seek(0)
file.write(transform_file_content(data))
file.truncate()
transformed_content = transform_file_content(data)

if args.check:
sys.stdout.writelines(ndiff(data.splitlines(keepends=True), transformed_content.splitlines(keepends=True)))
else:
file.seek(0)
file.write(transformed_content)
file.truncate()

if data != transformed_content:
has_changes = True

if __name__ == "__main__": # pragma: no cover
main()
return has_changes
22 changes: 16 additions & 6 deletions auto_typing_final/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ast_grep_py import Edit, SgNode, SgRoot

from auto_typing_final.finder import find_definitions_in_module
from auto_typing_final.finder import find_definitions_in_module, has_global_import_with_name

TYPING_FINAL = "typing.Final"
TYPING_FINAL_ANNOTATION_REGEX = re.compile(r"typing\.Final\[(.*)\]{1}")
Expand Down Expand Up @@ -109,12 +109,22 @@ def make_edits_from_operation(operation: Operation) -> Iterable[Edit]: # noqa:
yield node.replace(f"{left}: {new_annotation[0]} = {right}")


def make_edits_for_definitions(definitions: Iterable[list[SgNode]]) -> Iterable[Edit]:
for current_definitions in definitions:
yield from make_edits_from_operation(make_operation_from_assignments_to_one_name(current_definitions))
def make_edits_for_module(root: SgNode) -> list[Edit]:
edits: list[Edit] = []
has_added_final = False

for current_definitions in find_definitions_in_module(root):
operation = make_operation_from_assignments_to_one_name(current_definitions)
if isinstance(operation, AddFinal):
has_added_final = True
edits.extend(make_edits_from_operation(operation))

if has_added_final and not has_global_import_with_name(root, "typing"):
edits.append(root.replace(f"import typing\n{root.text()}"))

return edits


def transform_file_content(source: str) -> str:
root = SgRoot(source, "python").root()
edits = list(make_edits_for_definitions(find_definitions_in_module(root)))
return root.commit_edits(edits)
return root.commit_edits(make_edits_for_module(root))
Loading