From a254c041387e865f546456be0c1e64cbf5dcc025 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Thu, 23 Jan 2025 13:13:13 -0800 Subject: [PATCH 01/11] Add check for second value in sum: Logsumexp --- tests/fixtures/misc/checker/logsumexp.py | 1 + torchfix/visitors/misc/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 6473f99..b4309b5 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -12,3 +12,4 @@ y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(2.5))) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index e77de4f..8f6c57b 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -183,7 +183,7 @@ def visit_Call(self, node): node.args[0].value.args[0].value ) == "torch.exp" - ): + ) and len(node.args[0].value.args) > 1 and node.args[0].value.args[1].value is not None: self.add_violation( node, error_code=self.ERRORS[0].error_code, From ca6853c1e2d2821b3e84fff3dcad46c264c2caab Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 24 Jan 2025 12:16:02 -0800 Subject: [PATCH 02/11] Add condition for dim consideration --- tests/fixtures/misc/checker/logsumexp.py | 6 +++++- tests/fixtures/misc/checker/logsumexp.txt | 2 ++ torchfix/visitors/misc/__init__.py | 18 +++++++++++------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index b4309b5..b3bfdbd 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -5,6 +5,8 @@ # logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) +y = torch.log(torch.sum(torch.exp(1),1,1)) +y = torch.log(torch.sum(torch.exp(1),1,dim=1)) # not logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) @@ -12,4 +14,6 @@ y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) -y = torch.log(torch.sum(torch.exp(2.5))) +y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call +y = torch.sum(torch.log(torch.exp(x)), dim=1) +y = torch.sum(torch.log(torch.exp(x)), dim=None) \ No newline at end of file diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 4a4f5ec..1ad6c3a 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,2 +1,4 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. +8:5 TOR108 Use numerically stabilized `torch.logsumexp`. +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 8f6c57b..28a82bb 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -183,10 +183,14 @@ def visit_Call(self, node): node.args[0].value.args[0].value ) == "torch.exp" - ) and len(node.args[0].value.args) > 1 and node.args[0].value.args[1].value is not None: - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, - ) + ): + if len(node.args[0].value.args) > 1 and ( + node.args[0].value.args[1].value is not None + or self.has_specific_arg(node.args[0].value, "dim", -1) + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 8a28ec04bc60ea06e9701a63548d597af1ac14c3 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 11:57:50 -0800 Subject: [PATCH 03/11] Update logsumexp condition --- tests/fixtures/misc/checker/logsumexp.py | 2 ++ tests/fixtures/misc/checker/logsumexp.txt | 3 ++- torchfix/visitors/misc/__init__.py | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index b3bfdbd..e738fd5 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -7,9 +7,11 @@ y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) y = torch.log(torch.sum(torch.exp(1),1,1)) y = torch.log(torch.sum(torch.exp(1),1,dim=1)) +y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) # not logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) + y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 1ad6c3a..697a633 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,4 +1,5 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. 8:5 TOR108 Use numerically stabilized `torch.logsumexp`. -9:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. +10:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 28a82bb..380fe23 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -185,7 +185,10 @@ def visit_Call(self, node): == "torch.exp" ): if len(node.args[0].value.args) > 1 and ( - node.args[0].value.args[1].value is not None + self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 + ) + is not None or self.has_specific_arg(node.args[0].value, "dim", -1) ): self.add_violation( From beca6d704e2c24a717da96bd0fd530633c70d027 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 12:01:40 -0800 Subject: [PATCH 04/11] Reformat condition --- tests/fixtures/misc/checker/logsumexp.py | 1 - torchfix/visitors/misc/__init__.py | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index e738fd5..329aef6 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -11,7 +11,6 @@ # not logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) - y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 380fe23..d45be73 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,12 +184,8 @@ def visit_Call(self, node): ) == "torch.exp" ): - if len(node.args[0].value.args) > 1 and ( - self.get_specific_arg( - node.args[0].value, arg_name="dim", arg_pos=1 - ) - is not None - or self.has_specific_arg(node.args[0].value, "dim", -1) + if self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 ): self.add_violation( node, From 3637b274506f7e794561379071f841bf6eb5fab7 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 14:08:07 -0800 Subject: [PATCH 05/11] Add and update conditions for logsumexp --- tests/fixtures/misc/checker/logsumexp.py | 26 +++++++++++------------ tests/fixtures/misc/checker/logsumexp.txt | 3 +-- torchfix/visitors/misc/__init__.py | 18 ++++++++++------ 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 329aef6..4a3cc6c 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -3,18 +3,18 @@ b = torch.randn(5) # logsumexp -y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) -y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) -y = torch.log(torch.sum(torch.exp(1),1,1)) -y = torch.log(torch.sum(torch.exp(1),1,dim=1)) -y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) # has all the arguments for sum function call with keepdim as True +y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) # addition operation inside the exp function call +y = torch.log(torch.sum(torch.exp(x),dim=1,keepdim=False)) # has all the arguments for sum function call +y = torch.log(torch.sum(torch.exp(x), dim=1)) #default value of keepdim is False # not logsumexp -y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) -y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) -y = torch.log(2 + x) -y = torch.sum(torch.log(torch.exp(x)), 1) -y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) -y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call -y = torch.sum(torch.log(torch.exp(x)), dim=1) -y = torch.sum(torch.log(torch.exp(x)), dim=None) \ No newline at end of file +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) # cant have an addition operation inside the log function call +y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) # Cant have an addition operation inside the sum function call with the argument as it expects a tensor +y = torch.log(2 + x) # missing sum and exp +y = torch.sum(torch.log(torch.exp(x)), 1) # not proper order of log and sum +y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) #order of log,sum and exp is reversed +y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call and exp function call has an integer argument instead of a tensor +y = torch.log(torch.sum(torch.exp(x)), dim=1) #dim is not part of the sum fuction call +y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None +y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None \ No newline at end of file diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 697a633..1ad6c3a 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,5 +1,4 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. 8:5 TOR108 Use numerically stabilized `torch.logsumexp`. -9:5 TOR108 Use numerically stabilized `torch.logsumexp`. -10:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index d45be73..d441472 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -187,9 +187,15 @@ def visit_Call(self, node): if self.get_specific_arg( node.args[0].value, arg_name="dim", arg_pos=1 ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, - ) + if ( + self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 + ).value.value + != "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 0db07d194b1e35e3f4e060248d6990934664376a Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 14:15:47 -0800 Subject: [PATCH 06/11] Update conditions for logsumexp --- tests/fixtures/misc/checker/logsumexp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 4a3cc6c..f8bc6e2 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -5,7 +5,7 @@ # logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) # has all the arguments for sum function call with keepdim as True y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) # addition operation inside the exp function call -y = torch.log(torch.sum(torch.exp(x),dim=1,keepdim=False)) # has all the arguments for sum function call +y = torch.log(torch.sum(torch.exp(x),dim=1,keepdim=True)) # has all the arguments for sum function call y = torch.log(torch.sum(torch.exp(x), dim=1)) #default value of keepdim is False # not logsumexp From cf8e6939624826f8d0d538de76b93444a40f752c Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 16:12:29 -0800 Subject: [PATCH 07/11] Add conditions for tuple in dim --- tests/fixtures/misc/checker/logsumexp.py | 5 ++- tests/fixtures/misc/checker/logsumexp.txt | 3 +- torchfix/visitors/misc/__init__.py | 49 ++++++++++++++++++----- 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index f8bc6e2..332cde6 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -7,6 +7,7 @@ y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) # addition operation inside the exp function call y = torch.log(torch.sum(torch.exp(x),dim=1,keepdim=True)) # has all the arguments for sum function call y = torch.log(torch.sum(torch.exp(x), dim=1)) #default value of keepdim is False +y = torch.log(torch.sum(torch.exp(x), dim=(1,2))) #default value of keepdim is False # not logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) # cant have an addition operation inside the log function call @@ -17,4 +18,6 @@ y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call and exp function call has an integer argument instead of a tensor y = torch.log(torch.sum(torch.exp(x)), dim=1) #dim is not part of the sum fuction call y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None -y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None \ No newline at end of file +y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None +y = torch.log(torch.sum(torch.exp(x), dim=(1,None))) #dim argument cannot be a tuple with None +y = torch.log(torch.sum(torch.exp(x), dim=(None,None))) #dim argument cannot be a tuple with None \ No newline at end of file diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 1ad6c3a..697a633 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,4 +1,5 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. 8:5 TOR108 Use numerically stabilized `torch.logsumexp`. -9:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. +10:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index d441472..d90944b 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,18 +184,45 @@ def visit_Call(self, node): ) == "torch.exp" ): - if self.get_specific_arg( - node.args[0].value, arg_name="dim", arg_pos=1 + if ( + self.get_qualified_name_for_call(node.args[0].value) + == "torch.sum" ): if ( - self.get_specific_arg( - node.args[0].value, arg_name="dim", arg_pos=1 - ).value.value - != "None" + self.get_qualified_name_for_call( + node.args[0].value.args[0].value + ) + == "torch.exp" ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 ) + if dim_arg: # checks if dim argument is present + if isinstance( + dim_arg.value, cst.Integer + ) or isinstance( + dim_arg.value, cst.Tuple + ): # checks if dim argument is an integer or tuple + if ( + isinstance(dim_arg.value, cst.Integer) + and dim_arg.value.value != "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) + elif isinstance( + dim_arg.value, cst.Tuple + ) and all( + isinstance(element.value, cst.Integer) + and element.value.value != "None" + for element in dim_arg.value.elements + ): # checks if all elements of the tuple are not None + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From e0ee6ff76ebba14780d8fd29eb9f804f600bcd9f Mon Sep 17 00:00:00 2001 From: shivam096 Date: Fri, 31 Jan 2025 18:07:23 -0800 Subject: [PATCH 08/11] Resolve flake8 error --- torchfix/visitors/misc/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index d90944b..dd02013 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -219,7 +219,8 @@ def visit_Call(self, node): isinstance(element.value, cst.Integer) and element.value.value != "None" for element in dim_arg.value.elements - ): # checks if all elements of the tuple are not None + ): # checks if all elements of the + # tuple are not None self.add_violation( node, error_code=self.ERRORS[0].error_code, From 0c2edab341b71033dd85a997a9b41fa31c9d53f3 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Mon, 3 Feb 2025 18:35:04 -0800 Subject: [PATCH 09/11] Remove redundant code --- torchfix/visitors/misc/__init__.py | 72 ++++++++++++------------------ 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index dd02013..652440c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,46 +184,32 @@ def visit_Call(self, node): ) == "torch.exp" ): - if ( - self.get_qualified_name_for_call(node.args[0].value) - == "torch.sum" - ): - if ( - self.get_qualified_name_for_call( - node.args[0].value.args[0].value - ) - == "torch.exp" - ): - dim_arg = self.get_specific_arg( - node.args[0].value, arg_name="dim", arg_pos=1 - ) - if dim_arg: # checks if dim argument is present - if isinstance( - dim_arg.value, cst.Integer - ) or isinstance( - dim_arg.value, cst.Tuple - ): # checks if dim argument is an integer or tuple - if ( - isinstance(dim_arg.value, cst.Integer) - and dim_arg.value.value != "None" - ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, - ) - elif isinstance( - dim_arg.value, cst.Tuple - ) and all( - isinstance(element.value, cst.Integer) - and element.value.value != "None" - for element in dim_arg.value.elements - ): # checks if all elements of the - # tuple are not None - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, - ) + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 + ) + if dim_arg: # checks if dim argument is present + if isinstance(dim_arg.value, cst.Integer) or isinstance( + dim_arg.value, cst.Tuple + ): # checks if dim argument is an integer or tuple + if ( + isinstance(dim_arg.value, cst.Integer) + and dim_arg.value.value != "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) + elif isinstance(dim_arg.value, cst.Tuple) and all( + isinstance(element.value, cst.Integer) + and element.value.value != "None" + for element in dim_arg.value.elements + ): # checks if all elements of the + # tuple are not None + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 708eee48a4e55fdfa0d46948c0f9ac6bf8823424 Mon Sep 17 00:00:00 2001 From: shivam096 Date: Mon, 3 Feb 2025 18:38:24 -0800 Subject: [PATCH 10/11] Remove invalid testcase --- tests/fixtures/misc/checker/logsumexp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 332cde6..ff989e9 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -19,5 +19,4 @@ y = torch.log(torch.sum(torch.exp(x)), dim=1) #dim is not part of the sum fuction call y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None -y = torch.log(torch.sum(torch.exp(x), dim=(1,None))) #dim argument cannot be a tuple with None -y = torch.log(torch.sum(torch.exp(x), dim=(None,None))) #dim argument cannot be a tuple with None \ No newline at end of file +y = torch.log(torch.sum(torch.exp(x), dim=(1,None))) #dim argument cannot be a tuple with None \ No newline at end of file From 276252079f4c2bf35d5b40a9e55760d6644dc58c Mon Sep 17 00:00:00 2001 From: shivam096 Date: Mon, 3 Feb 2025 18:38:51 -0800 Subject: [PATCH 11/11] Remove invalid testcase --- tests/fixtures/misc/checker/logsumexp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index ff989e9..cd44e8d 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -18,5 +18,4 @@ y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call and exp function call has an integer argument instead of a tensor y = torch.log(torch.sum(torch.exp(x)), dim=1) #dim is not part of the sum fuction call y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None -y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None -y = torch.log(torch.sum(torch.exp(x), dim=(1,None))) #dim argument cannot be a tuple with None \ No newline at end of file +y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None \ No newline at end of file