Skip to content

Refactor: misc best practices #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def StderrSilencer(redirect: bool = True):
libc.close(orig_stderr)


def main() -> None:
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()

parser.add_argument(
"path",
nargs="+",
help=("Path to check/fix. Can be a directory, a file, or multiple of either."),
help="Path to check/fix. Can be a directory, a file, or multiple of either.",
)
parser.add_argument(
"--fix",
Expand Down Expand Up @@ -78,11 +78,11 @@ def main() -> None:
action="store_true",
)

args = parser.parse_args()
return parser.parse_args()

if not args.path:
parser.print_usage()
sys.exit(1)

def main() -> None:
args = _parse_args()

files = codemod.gather_files(args.path)

Expand Down
84 changes: 41 additions & 43 deletions torchfix/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC
from dataclasses import dataclass
from os.path import commonprefix
from typing import Dict, List, Optional, Sequence, Set, Tuple
from typing import List, Optional, Sequence, Set, Tuple, Mapping

import libcst as cst
from libcst.codemod.visitors import ImportItem
Expand All @@ -26,7 +26,7 @@ class LintViolation:

def flake8_result(self):
full_message = f"{self.error_code} {self.message}"
return (self.line, 1 + self.column, full_message, "TorchFix")
return self.line, 1 + self.column, full_message, "TorchFix"

def codemod_result(self) -> str:
fixable = f" [{CYAN}*{ENDC}]" if self.replacement is not None else ""
Expand Down Expand Up @@ -56,6 +56,7 @@ class TorchVisitor(cst.BatchableCSTVisitor, ABC):
ERRORS: List[TorchError]

def __init__(self) -> None:
super().__init__()
self.violations: List[LintViolation] = []
self.needed_imports: Set[ImportItem] = set()

Expand Down Expand Up @@ -128,8 +129,7 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
name_metadata = list(self.get_metadata(QualifiedNameProvider, node))
if not name_metadata:
return None
qualified_name = name_metadata[0].name
return qualified_name
return name_metadata[0].name


def call_with_name_changes(
Expand Down Expand Up @@ -157,7 +157,6 @@ def call_with_name_changes(
new_call_name = new_qualified_name.removeprefix(
commonprefix([qualified_name.removesuffix(call_name), new_qualified_name])
)
new_call_name = new_call_name
new_module_name = new_qualified_name.removesuffix(new_call_name).removesuffix(
"."
)
Expand All @@ -175,67 +174,66 @@ def call_with_name_changes(
# Replace with new_qualified_name.
if replacement is None:
return None
else:
return replacement, needed_imports

return replacement, needed_imports


def check_old_names_in_import_from(
node: cst.ImportFrom, old_new_name_map: Dict[str, Optional[str]]
node: cst.ImportFrom, old_new_name_map: Mapping[str, Optional[str]]
) -> Tuple[List[str], Optional[cst.ImportFrom]]:
"""
Using `old_new_name_map`, check if there are any old names in the import from.
Return a tuple of two elements:
1. List of all founds old names.
2. Optional replacement for the ImportFrom node.
"""
if node.module is None:
if node.module is None or not isinstance(node.names, Sequence):
return [], None

old_names: List[str] = []
replacement = None
if isinstance(node.names, Sequence):
new_names: List[str] = []
module = cst.helpers.get_full_name_for_node(node.module)

# `possible_new_modules` and `has_non_updated_names` are used
# to decide if we can replace the ImportFrom node.
new_modules: Set[str] = set()
has_non_updated_names = False

for name in node.names:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in old_new_name_map:
old_names.append(qualified_name)
new_qualified_name = old_new_name_map[qualified_name]
if new_qualified_name is not None:
new_module = ".".join(new_qualified_name.split(".")[:-1])
new_name = new_qualified_name.split(".")[-1]
new_names.append(new_name)
new_modules.add(new_module)
else:
has_non_updated_names = True
new_names: List[str] = []
module = cst.helpers.get_full_name_for_node(node.module)

# `possible_new_modules` and `has_non_updated_names` are used
# to decide if we can replace the ImportFrom node.
new_modules: Set[str] = set()
has_non_updated_names = False

for name in node.names:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in old_new_name_map:
old_names.append(qualified_name)
new_qualified_name = old_new_name_map[qualified_name]
if new_qualified_name is not None:
new_module = ".".join(new_qualified_name.split(".")[:-1])
new_name = new_qualified_name.split(".")[-1]
new_names.append(new_name)
new_modules.add(new_module)
else:
has_non_updated_names = True

# Replace only if the new module is the same for all names in the import.
if not has_non_updated_names and len(new_modules) == 1:
new_module = new_modules.pop()
import_aliases = list(node.names)
for i in range(len(import_aliases)):
import_aliases[i] = import_aliases[i].with_changes(
name=cst.Name(new_names[i])
)
replacement = node.with_changes(
module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501
names=import_aliases,
)
else:
has_non_updated_names = True

# Replace only if the new module is the same for all names in the import.
if not has_non_updated_names and len(new_modules) == 1:
new_module = new_modules.pop()
import_aliases = [
import_alias.with_changes(name=cst.Name(new_name))
for import_alias, new_name in zip(list(node.names), new_names)
]
replacement = node.with_changes(
module=cst.parse_expression(new_module),
names=import_aliases,
)

return old_names, replacement


def deep_multi_replace(tree, replacement_map):
class MultiChildReplacementTransformer(cst.CSTTransformer):
def __init__(self, replacement_map) -> None:
super().__init__()
self.replacement_map = replacement_map

def on_leave(self, original_node, updated_node):
Expand Down
16 changes: 8 additions & 8 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
import libcst.codemod as codemod

from .common import deep_multi_replace, TorchVisitor
from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .visitors.internal import TorchScopedLibraryVisitor

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor
from .visitors.nonpublic import TorchNonPublicAliasVisitor

from .visitors.vision import (
from .visitors import (
TorchDeprecatedSymbolsVisitor,
TorchNonPublicAliasVisitor,
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
TorchScopedLibraryVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchUnsafeLoadVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionSingletonImportVisitor,
)
from .visitors.security import TorchUnsafeLoadVisitor

__version__ = "0.5.0"

Expand Down
24 changes: 24 additions & 0 deletions torchfix/visitors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .internal import TorchScopedLibraryVisitor
from .misc import TorchReentrantCheckpointVisitor, TorchRequireGradVisitor
from .nonpublic import TorchNonPublicAliasVisitor
from .performance import TorchSynchronizedDataLoaderVisitor
from .security import TorchUnsafeLoadVisitor
from .vision import (
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionSingletonImportVisitor,
)

__all__ = [
"TorchDeprecatedSymbolsVisitor",
"TorchRequireGradVisitor",
"TorchScopedLibraryVisitor",
"TorchSynchronizedDataLoaderVisitor",
"TorchVisionDeprecatedPretrainedVisitor",
"TorchVisionDeprecatedToTensorVisitor",
"TorchVisionSingletonImportVisitor",
"TorchUnsafeLoadVisitor",
"TorchReentrantCheckpointVisitor",
"TorchNonPublicAliasVisitor",
]
34 changes: 17 additions & 17 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def read_deprecated_config(path=None):

super().__init__()
self.deprecated_config = read_deprecated_config(deprecated_config_path)
self.old_new_name_map = {}
for name in self.deprecated_config:
new_name = self.deprecated_config[name].get("replacement")
self.old_new_name_map[name] = new_name
self.old_new_name_map = {
name: self.deprecated_config[name].get("replacement")
for name in self.deprecated_config
}

def _call_replacement(
self, node: cst.Call, qualified_name: str
Expand All @@ -53,19 +53,19 @@ def _call_replacement(
replacement = None

if qualified_name in replacements_map:
replacement = replacements_map[qualified_name](node)
else:
# Replace names for functions that have drop-in replacement.
function_name_replacement = self.deprecated_config.get(
qualified_name, {}
).get("replacement", "")
if function_name_replacement:
replacement_and_imports = call_with_name_changes(
node, qualified_name, function_name_replacement
)
if replacement_and_imports is not None:
replacement, imports = replacement_and_imports
self.needed_imports.update(imports)
return replacements_map[qualified_name](node)

# Replace names for functions that have drop-in replacement.
function_name_replacement = self.deprecated_config.get(qualified_name, {}).get(
"replacement", ""
)
if function_name_replacement:
replacement_and_imports = call_with_name_changes(
node, qualified_name, function_name_replacement
)
if replacement_and_imports is not None:
replacement, imports = replacement_and_imports
self.needed_imports.update(imports)
return replacement

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
Expand Down
20 changes: 8 additions & 12 deletions torchfix/visitors/deprecated_symbols/chain_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
Replace `torch.chain_matmul` with `torch.linalg.multi_dot`, changing
multiple parameters to a list.
"""
matrices = []
matrices = [
cst.Element(value=arg.value) for arg in node.args if arg.keyword is None
]
matrices_arg = cst.Arg(value=cst.List(elements=matrices))

out_arg = None
for arg in node.args:
if arg.keyword is None:
matrices.append(cst.Element(value=arg.value))
elif arg.keyword.value == "out":
if arg.keyword is not None and arg.keyword.value == "out":
out_arg = arg
matrices_arg = cst.Arg(value=cst.List(elements=matrices))

if out_arg is None:
replacement_args = [matrices_arg]
else:
replacement_args = [matrices_arg, out_arg]
replacement_args = [matrices_arg] if out_arg is None else [matrices_arg, out_arg]
module_name = get_module_name(node, "torch")
replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)")
replacement = replacement.with_changes(args=replacement_args)

return replacement
return replacement.with_changes(args=replacement_args)
9 changes: 7 additions & 2 deletions torchfix/visitors/deprecated_symbols/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode:
and cst.ensure_type(upper_arg.value, cst.Name).value == "True"
):
replacement = cst.parse_expression(f"{module_name}.linalg.cholesky(A).mH")

# Make mypy happy
assert isinstance(replacement, (cst.Name, cst.Attribute))

old_node = cst.ensure_type(replacement.value, cst.Call).args
replacement = replacement.with_deep_changes(
# Ignore type error, see https://github.com/Instagram/LibCST/issues/963
old_node=cst.ensure_type(replacement.value, cst.Call).args, # type: ignore
# see https://github.com/Instagram/LibCST/issues/963
old_node=old_node, # type: ignore[arg-type]
value=[input_arg],
)
else:
Expand Down
4 changes: 1 addition & 3 deletions torchfix/visitors/deprecated_symbols/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]:
replacement_args = [input_arg]
module_name = get_module_name(node, "torch")
replacement = cst.parse_expression(f"{module_name}.linalg.qr(args)")
replacement = replacement.with_changes(args=replacement_args)

return replacement
return replacement.with_changes(args=replacement_args)
Loading