From 99f2f6da753563362e140f92d57e94f87c1ae7cc Mon Sep 17 00:00:00 2001 From: shivam096 Date: Wed, 5 Feb 2025 17:12:54 -0800 Subject: [PATCH 1/3] Reimplementation of GradNotSetToNonePattern from Torchtidy --- .../fixtures/performance/checker/zerograd.py | 16 +++++++++ .../fixtures/performance/checker/zerograd.txt | 2 ++ torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 6 +++- torchfix/visitors/performance/__init__.py | 33 +++++++++++++++++++ 5 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/performance/checker/zerograd.py create mode 100644 tests/fixtures/performance/checker/zerograd.txt diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py new file mode 100644 index 0000000..8f0d6fc --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + +x = torch.ones((100, 100)) +model = nn.Sequential() +optimizer = torch.optim.Adam(model.parameters()) + +# This should raise flags +optimizer.zero_grad(set_to_none=False) +model.zero_grad(set_to_none=False) + +# This should not raise flags +optimizer.zero_grad() +model.zero_grad() + + diff --git a/tests/fixtures/performance/checker/zerograd.txt b/tests/fixtures/performance/checker/zerograd.txt new file mode 100644 index 0000000..ed29bf4 --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.txt @@ -0,0 +1,2 @@ +9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). +10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). \ No newline at end of file diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index dae1a24..5e96e38 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -21,6 +21,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ) __version__ = "0.7.0" @@ -43,6 +44,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ] diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 5317d1b..45f2438 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -8,7 +8,10 @@ TorchRequireGradVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor -from .performance import TorchSynchronizedDataLoaderVisitor +from .performance import ( + TorchSynchronizedDataLoaderVisitor, + TorchGradNotSetToNonePatternVisitor, +) from .security import TorchUnsafeLoadVisitor from .vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -30,4 +33,5 @@ "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", + "TorchGradNotSetToNonePatternVisitor", ] diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 249df4c..6a89202 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -32,3 +32,36 @@ def visit_Call(self, node): error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message(), ) + + +class TorchGradNotSetToNonePatternVisitor(TorchVisitor): + """ + Reimplementation of GradNotSetToNonePattern from + https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py + """ + + ERRORS = [ + TorchError( + "TOR402", + ( + "Detected gradient set 
to zero instead of None. "
+                "Please add 'set_to_none=True' when calling zero_grad()."
+            ),
+        )
+    ]
+
+    def visit_Call(self, node):
+        qualified_name = self.get_qualified_name_for_call(node)
+
+        if qualified_name and qualified_name.endswith("zero_grad"):
+
+            set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0)
+
+            # hasattr check to satisfy mypy
+            if set_to_none_arg and hasattr(set_to_none_arg.value, "value"):
+                if set_to_none_arg.value == "False":
+                    self.add_violation(
+                        node,
+                        error_code=self.ERRORS[0].error_code,
+                        message=self.ERRORS[0].message(),
+                    )

From 3f28da06ab05770a8d0f9a4000e61b30a6fea486 Mon Sep 17 00:00:00 2001
From: shivam096
Date: Wed, 5 Feb 2025 20:02:05 -0800
Subject: [PATCH 2/3] Update linting

---
 torchfix/visitors/performance/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py
index 6a89202..0558af5 100644
--- a/torchfix/visitors/performance/__init__.py
+++ b/torchfix/visitors/performance/__init__.py
@@ -59,7 +59,7 @@ def visit_Call(self, node):
 
             # hasattr check to satisfy mypy
             if set_to_none_arg and hasattr(set_to_none_arg.value, "value"):
-                if set_to_none_arg.value == "False":
+                if set_to_none_arg.value.value == "False":
                     self.add_violation(
                         node,
                         error_code=self.ERRORS[0].error_code,
                         message=self.ERRORS[0].message(),
                     )

From 8b5a773e5998b0aa6883c1350a64d68f4bae6e99 Mon Sep 17 00:00:00 2001
From: shivam096
Date: Thu, 6 Feb 2025 11:34:32 -0800
Subject: [PATCH 3/3] Reimplementation of OptimizerSingleTensorPattern from
 Torchtidy

---
 .../performance/checker/singletensor.py      | 17 +++++++++
 .../performance/checker/singletensor.txt     |  3 ++
 .../fixtures/performance/checker/zerograd.py |  2 +-
 torchfix/torchfix.py                         |  2 +
 torchfix/visitors/__init__.py                |  2 +
 torchfix/visitors/performance/__init__.py    | 38 +++++++++++++++++++
 6 files changed, 63 insertions(+), 1 deletion(-)
 create mode 100644 tests/fixtures/performance/checker/singletensor.py
 create mode 100644 tests/fixtures/performance/checker/singletensor.txt

diff --git a/tests/fixtures/performance/checker/singletensor.py b/tests/fixtures/performance/checker/singletensor.py
new file mode 100644
index 0000000..0e1a6cb
--- /dev/null
+++ b/tests/fixtures/performance/checker/singletensor.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+x = torch.ones((100, 100))
+model = nn.Sequential()
+
+
+# These should raise flags
+optimizer_adam = torch.optim.Adam(model.parameters())
+optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01)
+optimizer_adamw = torch.optim.AdamW(model.parameters())
+
+# These should not raise flags
+optimizer_adam = torch.optim.Adam(model.parameters(), foreach=True)
+optimizer_sgd = torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)
+optimizer_adamw = torch.optim.AdamW(model.parameters(), foreach=True)
+optimizer_adamw = torch.optim.AdamW(model.parameters(), foreach=False)
\ No newline at end of file
diff --git a/tests/fixtures/performance/checker/singletensor.txt b/tests/fixtures/performance/checker/singletensor.txt
new file mode 100644
index 0000000..8c4a847
--- /dev/null
+++ b/tests/fixtures/performance/checker/singletensor.txt
@@ -0,0 +1,3 @@
+9:18 TOR403 Detected optimizer running with single-tensor implementation. Please enable multi-tensor implementation by passing 'foreach=True' to the optimizer.
+10:17 TOR403 Detected optimizer running with single-tensor implementation. Please enable multi-tensor implementation by passing 'foreach=True' to the optimizer.
+11:19 TOR403 Detected optimizer running with single-tensor implementation. Please enable multi-tensor implementation by passing 'foreach=True' to the optimizer.
\ No newline at end of file
diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py
index 8f0d6fc..9188224 100644
--- a/tests/fixtures/performance/checker/zerograd.py
+++ b/tests/fixtures/performance/checker/zerograd.py
@@ -3,7 +3,7 @@
 
 x = torch.ones((100, 100))
 model = nn.Sequential()
-optimizer = torch.optim.Adam(model.parameters())
+optimizer = torch.optim.Adam(model.parameters(), foreach=True)
 
 # This should raise flags
 optimizer.zero_grad(set_to_none=False)
diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py
index 5e96e38..11c079c 100644
--- a/torchfix/torchfix.py
+++ b/torchfix/torchfix.py
@@ -22,6 +22,7 @@
     TorchVisionDeprecatedToTensorVisitor,
     TorchVisionSingletonImportVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 )
 
 __version__ = "0.7.0"
@@ -45,6 +46,7 @@
     TorchVisionDeprecatedToTensorVisitor,
     TorchVisionSingletonImportVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 ]
diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py
index 45f2438..e2a9e3d 100644
--- a/torchfix/visitors/__init__.py
+++ b/torchfix/visitors/__init__.py
@@ -11,6 +11,7 @@
 from .performance import (
     TorchSynchronizedDataLoaderVisitor,
     TorchGradNotSetToNonePatternVisitor,
+    TorchOptimizerSingleTensorPatternVisitor,
 )
 from .security import TorchUnsafeLoadVisitor
 from .vision import (
@@ -34,4 +35,5 @@
     "TorchVisionDeprecatedToTensorVisitor",
     "TorchVisionSingletonImportVisitor",
     "TorchGradNotSetToNonePatternVisitor",
+    "TorchOptimizerSingleTensorPatternVisitor",
 ]
diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py
index 0558af5..d9bac76 100644
--- a/torchfix/visitors/performance/__init__.py
+++ b/torchfix/visitors/performance/__init__.py
@@ -65,3 +65,41 @@ def visit_Call(self, node):
                 error_code=self.ERRORS[0].error_code,
                 message=self.ERRORS[0].message(),
             )
+
+
+class TorchOptimizerSingleTensorPatternVisitor(TorchVisitor):
+    """
+    Reimplementation of OptimizerSingleTensorPattern from
+    https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
+    """
+
+    ERRORS = [
+        TorchError(
+            "TOR403",
+            (
+                "Detected optimizer running with single-tensor implementation. "
+                "Please enable multi-tensor implementation by passing "
+                "'foreach=True' to the optimizer."
+            ),
+        )
+    ]
+
+    optimizers_with_foreach = ["Adam", "SGD", "AdamW"]
+
+    def visit_Call(self, node):
+
+        qualified_name = self.get_qualified_name_for_call(node)
+
+        for optimizer in self.optimizers_with_foreach:
+
+            if qualified_name and qualified_name.endswith(optimizer):
+
+                foreach_arg = self.get_specific_arg(node, arg_name="foreach", arg_pos=1)
+
+                if foreach_arg is None:
+
+                    self.add_violation(
+                        node,
+                        error_code=self.ERRORS[0].error_code,
+                        message=self.ERRORS[0].message(),
+                    )
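
Note: for reference, a minimal sketch of the usage these two checkers steer
code toward (assuming a PyTorch version recent enough to support both the
optimizer-level `foreach` flag and `zero_grad(set_to_none=...)`):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)

    # TOR403: foreach=True selects the multi-tensor implementation, which
    # batches per-parameter updates into fewer, larger kernel launches.
    optimizer = torch.optim.Adam(model.parameters(), foreach=True)

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()

    # TOR402: set_to_none=True releases the gradient tensors instead of
    # zeroing them in place, which skips the zero-fill kernels and lets
    # the gradient memory be reused.
    optimizer.zero_grad(set_to_none=True)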