Skip to content

Add rules for datasets and transforms imports #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/fixtures/vision/checker/models_import.txt

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
import torchvision.models
from torchvision.models import *
import torchvision.models as models, torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
4 changes: 4 additions & 0 deletions tests/fixtures/vision/checker/singleton_import.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
7:1 TOR203 Consider replacing 'import torchvision.datasets as datasets' with 'from torchvision import datasets'.
8:1 TOR203 Consider replacing 'import torchvision.transforms as transforms' with 'from torchvision import transforms'.
5 changes: 0 additions & 5 deletions tests/fixtures/vision/codemod/models_import.py

This file was deleted.

5 changes: 0 additions & 5 deletions tests/fixtures/vision/codemod/models_import.py.out

This file was deleted.

9 changes: 9 additions & 0 deletions tests/fixtures/vision/codemod/singleton_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torchvision.models as models
import torchvision.models as cnn
import torchvision.datasets as datasets
import torchvision.datasets as datasets_alt
import torchvision.transforms as transforms
import torchvision.transforms as transforms_alt

# don't touch if more than one name imported
import torchvision.models as models, torch
9 changes: 9 additions & 0 deletions tests/fixtures/vision/codemod/singleton_import.py.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from torchvision import models
import torchvision.models as cnn
from torchvision import datasets
import torchvision.datasets as datasets_alt
from torchvision import transforms
import torchvision.transforms as transforms_alt

# don't touch if more than one name imported
import torchvision.models as models, torch
4 changes: 2 additions & 2 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
TorchVisionSingletonImportVisitor,
)
from .visitors.security import TorchUnsafeLoadVisitor

Expand All @@ -33,7 +33,7 @@
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
TorchVisionSingletonImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
TorchNonPublicAliasVisitor,
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401
from .models_import import TorchVisionModelsImportVisitor # noqa: F401
from .singleton_import import TorchVisionSingletonImportVisitor # noqa: F401
42 changes: 0 additions & 42 deletions torchfix/visitors/vision/models_import.py

This file was deleted.

60 changes: 60 additions & 0 deletions torchfix/visitors/vision/singleton_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import libcst as cst
import libcst.matchers as m

from ...common import TorchError, TorchVisitor


class TorchVisionSingletonImportVisitor(TorchVisitor):
ERRORS = [
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.datasets as datasets' "
"with 'from torchvision import datasets'."
),
),
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.models as models' "
"with 'from torchvision import models'."
),
),
TorchError(
"TOR203",
(
"Consider replacing 'import torchvision.transforms as transforms' "
"with 'from torchvision import transforms'."
),
),
]

# Keep attr order in sync with ERRORS.
REPLACEABLE_ATTRS = ["datasets", "models", "transforms"]

def visit_Import(self, node: cst.Import) -> None:
replacement = None
for i, import_attr in enumerate(self.REPLACEABLE_ATTRS):
for imported_item in node.names:
if m.matches(
imported_item,
m.ImportAlias(
name=m.Attribute(
value=m.Name("torchvision"), attr=m.Name(import_attr)
),
asname=m.AsName(name=m.Name(import_attr)),
),
):
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name(import_attr))],
)
self.add_violation(
node,
error_code=self.ERRORS[i].error_code,
message=self.ERRORS[i].message(),
replacement=replacement,
)
break