diff --git a/CHANGELOG.md b/CHANGELOG.md index bea9b72..799ba6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add rule GCI102 avoid non pinned memory for dataloader. This rule is specific to PyTorch and so AI +- Add utils file + ### Changed - compatibility updates for SonarQube 25.5.0 diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java index c385979..02a8cc5 100644 --- a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java +++ b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java @@ -40,7 +40,8 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe AvoidFullSQLRequest.class, AvoidListComprehensionInIterations.class, DetectUnoptimizedImageFormat.class, - AvoidMultipleIfElseStatementCheck.class + AvoidMultipleIfElseStatementCheck.class, + AvoidNonPinnedMemoryForDataloaders.class ); public static final String LANGUAGE = "py"; diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloaders.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloaders.java new file mode 100644 index 0000000..29f0455 --- /dev/null +++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloaders.java @@ -0,0 +1,66 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.greencodeinitiative.creedengo.python.checks; + +import org.sonar.check.Priority; +import org.sonar.check.Rule; +import org.sonar.plugins.python.api.PythonSubscriptionCheck; +import org.sonar.plugins.python.api.tree.Tree; +import org.sonar.plugins.python.api.tree.CallExpression; +import org.sonar.plugins.python.api.tree.RegularArgument; +import org.sonar.plugins.python.api.tree.Expression; +import org.sonar.plugins.python.api.tree.Name; + +import java.util.Optional; + +import static org.sonar.plugins.python.api.tree.Tree.Kind.NAME; + +@Rule(key = "GCI102") +public class AvoidNonPinnedMemoryForDataloaders extends PythonSubscriptionCheck { + + public static final String RULE_KEY = "P3"; + private static final String dataloaderFullyQualifiedName = "torch.utils.data.DataLoader"; + private static final int pinMemoryArgumentPosition = 7; + private static final String pinMemoryArgumentName = "pin_memory"; + protected static final String MESSAGE = "Use pinned memory to reduce data transfer in RAM."; + + @Override + public void initialize(Context context) { + context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> { + CallExpression callExpression = (CallExpression) ctx.syntaxNode(); + + if (Utils.getQualifiedName(callExpression).equals(dataloaderFullyQualifiedName)) { + RegularArgument numWorkersArgument = Utils.nthArgumentOrKeyword(pinMemoryArgumentPosition, + pinMemoryArgumentName, + callExpression.arguments()); + + if (numWorkersArgument == null) { + ctx.addIssue(callExpression, MESSAGE); + } else { + Optional.of(numWorkersArgument).filter(this::checkBadValuesForPinMemory) + .ifPresent(arg -> ctx.addIssue(arg, MESSAGE)); + } + } + }); + } + + private boolean checkBadValuesForPinMemory(RegularArgument pinMemoryArgument) { + Expression expression = pinMemoryArgument.expression(); + return expression.is(NAME) && ((Name) expression).name().equals("False"); + } +} \ No newline at end of file diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java new file mode 100644 index 0000000..44f2183 --- /dev/null +++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java @@ -0,0 +1,102 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.greencodeinitiative.creedengo.python.checks; + +import org.sonar.plugins.python.api.SubscriptionContext; +import org.sonar.plugins.python.api.symbols.Symbol; +import org.sonar.plugins.python.api.tree.Argument; +import org.sonar.plugins.python.api.tree.AssignmentStatement; +import org.sonar.plugins.python.api.tree.Tree; +import org.sonar.plugins.python.api.tree.RegularArgument; +import org.sonar.plugins.python.api.tree.Name; +import org.sonar.plugins.python.api.tree.CallExpression; +import org.sonar.plugins.python.api.tree.Expression; + +import javax.annotation.CheckForNull; +import java.util.List; +import java.util.Objects; + +public class Utils { + + private static boolean hasKeyword(Argument argument, String keyword) { + if (!argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) { + return false; + } else { + Name keywordArgument = ((RegularArgument) argument).keywordArgument(); + return keywordArgument != null && keywordArgument.name().equals(keyword); + } + } + + @CheckForNull + public static RegularArgument nthArgumentOrKeyword(int argPosition, String keyword, List arguments) { + for (int i = 0; i < arguments.size(); ++i) { + Argument argument = (Argument) arguments.get(i); + if (hasKeyword(argument, keyword)) { + return (RegularArgument) argument; + } + + if (argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) { + RegularArgument regularArgument = (RegularArgument) argument; + if (regularArgument.keywordArgument() == null && argPosition == i) { + return regularArgument; + } + } + } + + return null; + } + + public static String getQualifiedName(CallExpression callExpression) { + Symbol symbol = callExpression.calleeSymbol(); + + return symbol != null && symbol.fullyQualifiedName() != null ? symbol.fullyQualifiedName() : ""; + } + + public static String getMethodName(CallExpression callExpression) { + Symbol symbol = callExpression.calleeSymbol(); + return symbol != null && symbol.name() != null ? symbol.name() : ""; + } + + public static List getArgumentsFromCall(CallExpression callExpression) { + try { + return Objects.requireNonNull(callExpression.argumentList()).arguments(); + } catch (NullPointerException e) { + return List.of(); + } + } + + public static String getVariableName(SubscriptionContext context) { + Tree node = context.syntaxNode(); + Tree current = node; + while (current != null && !current.is(Tree.Kind.ASSIGNMENT_STMT)) { + current = current.parent(); + } + if (current != null && current.is(Tree.Kind.ASSIGNMENT_STMT)) { + AssignmentStatement assignment = (AssignmentStatement) current; + if (!assignment.lhsExpressions().isEmpty() && !assignment.lhsExpressions().get(0).expressions().isEmpty()) { + Expression leftExpr = assignment.lhsExpressions().get(0).expressions().get(0); + if (leftExpr.is(Tree.Kind.NAME)) { + Name variableName = (Name) leftExpr; + return variableName.name(); + } + } + + } + return null; + } +} diff --git a/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloadersTest.java b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloadersTest.java new file mode 100644 index 0000000..d3e8346 --- /dev/null +++ b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidNonPinnedMemoryForDataloadersTest.java @@ -0,0 +1,29 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.greencodeinitiative.creedengo.python.checks; + +import org.junit.Test; +import org.sonar.python.checks.utils.PythonCheckVerifier; + +public class AvoidNonPinnedMemoryForDataloadersTest { + + @Test + public void test() { + PythonCheckVerifier.verify("src/test/resources/checks/avoidNonPinnedMemoryForDataloaders.py", new AvoidNonPinnedMemoryForDataloaders()); + } +} diff --git a/src/test/resources/checks/avoidNonPinnedMemoryForDataloaders.py b/src/test/resources/checks/avoidNonPinnedMemoryForDataloaders.py new file mode 100644 index 0000000..44373b9 --- /dev/null +++ b/src/test/resources/checks/avoidNonPinnedMemoryForDataloaders.py @@ -0,0 +1,26 @@ +import torch +from torch.utils.data import DataLoader +from torch.utils.data import DataLoader as DL +import torch.utils as utils +import nottorch + +dl = torch.utils.data.DataLoader(dataset) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = torch.utils.data.DataLoader(dataset, num_workers=3, batch_size=1, shuffle=False, pin_memory=False) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=False) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = torch.utils.data.DataLoader(num_workers=5, batch_size=2, shuffle=True) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = DataLoader(dataset, 1, False, None, None, 0, None, False, False) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = utils.data.DataLoader(dataset, batch_size=1, False, None, None, 0, None, False, True) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = utils.data.DataLoader(dataset, pin_memory=False) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} +dl = DL(dataset, pin_memory=False) # Noncompliant {{Use pinned memory to reduce data transfer in RAM.}} + +dl = torch.utils.data.DataLoader(dataset, num_workers=3, batch_size=1, shuffle=False, pin_memory=True) +dl = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=False, pin_memory=True) +dl = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=False, pin_memory=True) +dl = utils.data.DataLoader(dataset, pin_memory=True) +dl = utils.data.DataLoader(dataset, batch_size=1, False, None, None, 0, None, True, True) +dl = DataLoader(dataset, batch_size=1, False, None, None, 0, None, True, False) +dl = torch.utils.data.DataLoader(torchvision.datasets.MNIST('/files/', train=True, download=True), pin_memory=True) +dl = DL(dataset, pin_memory=True) + +dl = nottorch.utils.data.DataLoader(dataset, pin_memory=True) +dl = nottorch.utils.data.DataLoader(dataset, pin_memory=False) \ No newline at end of file