-
Notifications
You must be signed in to change notification settings - Fork 20
Add check for second value in sum: Logsumexp #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
a254c04
ca6853c
8a28ec0
beca6d7
3637b27
0db07d1
cf8e693
e0ee6ff
0c2edab
708eee4
2762520
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
6:5 TOR108 Use numerically stabilized `torch.logsumexp`. | ||
7:5 TOR108 Use numerically stabilized `torch.logsumexp`. | ||
8:5 TOR108 Use numerically stabilized `torch.logsumexp`. | ||
9:5 TOR108 Use numerically stabilized `torch.logsumexp`. | ||
10:5 TOR108 Use numerically stabilized `torch.logsumexp`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -184,9 +184,46 @@ def visit_Call(self, node): | |
) | ||
== "torch.exp" | ||
): | ||
self.add_violation( | ||
node, | ||
error_code=self.ERRORS[0].error_code, | ||
message=self.ERRORS[0].message(), | ||
replacement=None, | ||
) | ||
if ( | ||
shivam096 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.get_qualified_name_for_call(node.args[0].value) | ||
== "torch.sum" | ||
): | ||
if ( | ||
self.get_qualified_name_for_call( | ||
node.args[0].value.args[0].value | ||
) | ||
== "torch.exp" | ||
): | ||
dim_arg = self.get_specific_arg( | ||
node.args[0].value, arg_name="dim", arg_pos=1 | ||
) | ||
if dim_arg: # checks if dim argument is present | ||
if isinstance( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These lines are redundant, no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Removed it since they were test code. |
||
dim_arg.value, cst.Integer | ||
) or isinstance( | ||
dim_arg.value, cst.Tuple | ||
): # checks if dim argument is an integer or tuple | ||
if ( | ||
isinstance(dim_arg.value, cst.Integer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the condition checks if the value is of type integer and also makes sure that the value it holds is also not None since Tuples in |
||
and dim_arg.value.value != "None" | ||
): | ||
self.add_violation( | ||
node, | ||
error_code=self.ERRORS[0].error_code, | ||
message=self.ERRORS[0].message(), | ||
replacement=None, | ||
) | ||
elif isinstance( | ||
dim_arg.value, cst.Tuple | ||
) and all( | ||
isinstance(element.value, cst.Integer) | ||
and element.value.value != "None" | ||
for element in dim_arg.value.elements | ||
): # checks if all elements of the | ||
# tuple are not None | ||
self.add_violation( | ||
node, | ||
error_code=self.ERRORS[0].error_code, | ||
message=self.ERRORS[0].message(), | ||
replacement=None, | ||
) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to check for
dim=(None,None)
ordim=(1,None)
, it can not happen because if presentdim
is an int or tuple of ints: https://pytorch.org/docs/stable/generated/torch.sum.html