Skip to content

GCI107 UseTorchFromNumpy #AI #Python #DLG #Build #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add rule GCI107 Torch from numpy, the rule isn't finished yet

### Changed

- compatibility updates for SonarQube 25.5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe
AvoidFullSQLRequest.class,
AvoidListComprehensionInIterations.class,
DetectUnoptimizedImageFormat.class,
AvoidMultipleIfElseStatementCheck.class
AvoidMultipleIfElseStatementCheck.class,
UseTorchFromNumpy.class
);

public static final String LANGUAGE = "py";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
* Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.greencodeinitiative.creedengo.python.checks;

import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.tree.*;

import java.util.HashSet;
import java.util.Set;

import static org.sonar.plugins.python.api.tree.Tree.Kind.*;

/**
* Rule to enforce the use of torch.from_numpy() instead of torch.tensor() when working with NumPy arrays.
* This optimization reduces memory usage and computational overhead by avoiding unnecessary data copying.
*/
@Rule(key = "GCI107")
public class UseTorchFromNumpy extends PythonSubscriptionCheck {

public static final String DESCRIPTION = "Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays";
private static final String NUMPY_ARRAY_FUNCTION = "numpy.array";
private static final String TORCH_TENSOR_FUNCTION = "torch.tensor";

private final Set<String> numpyArrayVariables = new HashSet<>();

@Override
public void initialize(Context context) {
context.registerSyntaxNodeConsumer(ASSIGNMENT_STMT, this::visitAssignmentStatement);
context.registerSyntaxNodeConsumer(CALL_EXPR, this::visitCallExpression);
}

private void visitAssignmentStatement(SubscriptionContext ctx) {
var assignmentStmt = (AssignmentStatement) ctx.syntaxNode();
var value = assignmentStmt.assignedValue();

if (value.is(CALL_EXPR) && isNumpyArrayCreation((CallExpression) value)) {
String variableName = Utils.getVariableName(ctx);
if (variableName != null) {
numpyArrayVariables.add(variableName);
}
}
}

private boolean isNumpyArrayCreation(CallExpression callExpression) {
return NUMPY_ARRAY_FUNCTION.equals(Utils.getQualifiedName(callExpression));
}

private void visitCallExpression(SubscriptionContext ctx) {
var callExpression = (CallExpression) ctx.syntaxNode();

if (!TORCH_TENSOR_FUNCTION.equals(Utils.getQualifiedName(callExpression)) && !TORCH_TENSOR_FUNCTION.equals(callExpression.callee().firstToken().value()+"."+callExpression.calleeSymbol().name())) {
return;
}

for (Argument arg : callExpression.arguments()) {
if (!arg.is(REGULAR_ARGUMENT)) {
continue;
}

var regArg = (RegularArgument) arg;
var argumentExpression = regArg.expression();

// Case 1: Direct np.array call in the argument
if (argumentExpression.is(CALL_EXPR)) {
var argCallExpression = (CallExpression) argumentExpression;
if (isNumpyArrayCreation(argCallExpression)) {
ctx.addIssue(argumentExpression, DESCRIPTION);
continue;
}
}

// Case 2: Variable reference to a previously defined numpy array
if (argumentExpression.is(NAME)) {
var name = (Name) argumentExpression;
var variableName = name.name();

if (numpyArrayVariables.contains(variableName)) {
ctx.addIssue(argumentExpression, DESCRIPTION);
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
* Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.greencodeinitiative.creedengo.python.checks;

import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.tree.Argument;
import org.sonar.plugins.python.api.tree.AssignmentStatement;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.Expression;

import javax.annotation.CheckForNull;
import java.util.List;
import java.util.Objects;

public class Utils {

private static boolean hasKeyword(Argument argument, String keyword) {
if (!argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) {
return false;
} else {
Name keywordArgument = ((RegularArgument) argument).keywordArgument();
return keywordArgument != null && keywordArgument.name().equals(keyword);
}
}

@CheckForNull
public static RegularArgument nthArgumentOrKeyword(int argPosition, String keyword, List<Argument> arguments) {
for (int i = 0; i < arguments.size(); ++i) {
Argument argument = (Argument) arguments.get(i);
if (hasKeyword(argument, keyword)) {
return (RegularArgument) argument;
}

if (argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) {
RegularArgument regularArgument = (RegularArgument) argument;
if (regularArgument.keywordArgument() == null && argPosition == i) {
return regularArgument;
}
}
}

return null;
}

public static String getQualifiedName(CallExpression callExpression) {
Symbol symbol = callExpression.calleeSymbol();

return symbol != null && symbol.fullyQualifiedName() != null ? symbol.fullyQualifiedName() : "";
}

public static String getMethodName(CallExpression callExpression) {
Symbol symbol = callExpression.calleeSymbol();
return symbol != null && symbol.name() != null ? symbol.name() : "";
}

public static List<Argument> getArgumentsFromCall(CallExpression callExpression) {
try {
return Objects.requireNonNull(callExpression.argumentList()).arguments();
} catch (NullPointerException e) {
return List.of();
}
}

public static String getVariableName(SubscriptionContext context) {
Tree node = context.syntaxNode();
Tree current = node;
while (current != null && !current.is(Tree.Kind.ASSIGNMENT_STMT)) {
current = current.parent();
}
if (current != null && current.is(Tree.Kind.ASSIGNMENT_STMT)) {
AssignmentStatement assignment = (AssignmentStatement) current;
if (!assignment.lhsExpressions().isEmpty() && !assignment.lhsExpressions().get(0).expressions().isEmpty()) {
Expression leftExpr = assignment.lhsExpressions().get(0).expressions().get(0);
if (leftExpr.is(Tree.Kind.NAME)) {
Name variableName = (Name) leftExpr;
return variableName.name();
}
}

}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
* Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.greencodeinitiative.creedengo.python.checks;

import org.junit.Test;
import org.sonar.python.checks.utils.PythonCheckVerifier;

public class UseTorchFromNumpyTest {

@Test
public void test() {
PythonCheckVerifier.verify("src/test/resources/checks/useTorchFromNumpy.py", new UseTorchFromNumpy());
}
}
53 changes: 53 additions & 0 deletions src/test/resources/checks/useTorchFromNumpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import torch as tt

np_array = np.array([1, 2, 3])


torch_tensor = tt.from_numpy(np_array) # Compliant

torch = tt.tensor(np_array) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}
# Case 1: Standard imports
import numpy
import torch

numpy_array = numpy.array([1, 2, 3, 4])

compliant1 = torch.from_numpy(numpy_array) # Compliant

non_compliant1 = torch.tensor(numpy_array) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}

# Case 2: Aliased imports
import numpy as np
import torch as tt

compliant2 = tt.from_numpy(numpy_array) # Compliant

non_compliant2 = tt.tensor(numpy_array) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}

# Case 3: From imports
from numpy import array
from torch import tensor, from_numpy


compliant3 = from_numpy(numpy_array) # Compliant

non_compliant3 = tensor(numpy_array) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}

# Case 4: From imports with aliases
from numpy import array as np_arr
from torch import tensor as t_tensor, from_numpy as t_from_numpy


compliant4 = t_from_numpy(numpy_array) # Compliant
non_compliant4 = t_tensor(numpy_array) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}

# Case 5: Direct np call as function argument
compliant5 = tt.from_numpy(np.array([1, 2, 3]))
non_compliant5 = tt.tensor(np.array([1, 2, 3])) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}

# Case 6: Alias direct np call as function argument
compliant5 = t_from_numpy(np.array([1, 2, 3]))
non_compliant6 = t_tensor(np.array([1, 2, 3])) # Noncompliant {{Use torch.from_numpy() instead of torch.tensor() to create tensors from numpy arrays}}


Loading