diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 6473f99..cd44e8d 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -3,12 +3,19 @@ b = torch.randn(5) # logsumexp -y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) -y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) # has all the arguments for sum function call with keepdim as True +y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) # addition operation inside the exp function call +y = torch.log(torch.sum(torch.exp(x),dim=1,keepdim=True)) # has all the arguments for sum function call +y = torch.log(torch.sum(torch.exp(x), dim=1)) #default value of keepdim is False +y = torch.log(torch.sum(torch.exp(x), dim=(1,2))) #default value of keepdim is False # not logsumexp -y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) -y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) -y = torch.log(2 + x) -y = torch.sum(torch.log(torch.exp(x)), 1) -y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) # cant have an addition operation inside the log function call +y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) # Cant have an addition operation inside the sum function call with the argument as it expects a tensor +y = torch.log(2 + x) # missing sum and exp +y = torch.sum(torch.log(torch.exp(x)), 1) # not proper order of log and sum +y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) #order of log,sum and exp is reversed +y = torch.log(torch.sum(torch.exp(2.5))) # this should not be flagged as the second argument is missing for sum function call and exp function call has an integer argument instead of a tensor +y = torch.log(torch.sum(torch.exp(x)), dim=1) #dim is not part of the sum fuction call +y = torch.log(torch.sum(torch.exp(x)), dim=None) #dim is not part of the sum fuction call and dim is None +y = torch.log(torch.sum(torch.exp(x), keepdim=True, dim=None)) #dim argument cannot be None \ No newline at end of file diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 4a4f5ec..697a633 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,2 +1,5 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. +8:5 TOR108 Use numerically stabilized `torch.logsumexp`. +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. +10:5 TOR108 Use numerically stabilized `torch.logsumexp`. \ No newline at end of file diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index e77de4f..652440c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,9 +184,32 @@ def visit_Call(self, node): ) == "torch.exp" ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 ) + if dim_arg: # checks if dim argument is present + if isinstance(dim_arg.value, cst.Integer) or isinstance( + dim_arg.value, cst.Tuple + ): # checks if dim argument is an integer or tuple + if ( + isinstance(dim_arg.value, cst.Integer) + and dim_arg.value.value != "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) + elif isinstance(dim_arg.value, cst.Tuple) and all( + isinstance(element.value, cst.Integer) + and element.value.value != "None" + for element in dim_arg.value.elements + ): # checks if all elements of the + # tuple are not None + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )