diff --git a/CHANGELOG.md b/CHANGELOG.md
index f97f08d..b868024 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
+- Add rule GCI101 Avoid Conv Bias Before Batch Normalization, a rule specific to deep learning
+- Add a utils file to keep all utility functions in one place
+
### Changed
- compatibility updates for SonarQube 25.5.0
diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java
index c385979..e3c7520 100644
--- a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java
+++ b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java
@@ -40,7 +40,8 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe
AvoidFullSQLRequest.class,
AvoidListComprehensionInIterations.class,
DetectUnoptimizedImageFormat.class,
- AvoidMultipleIfElseStatementCheck.class
+ AvoidMultipleIfElseStatementCheck.class,
+ AvoidConvBiasBeforeBatchNorm.class
);
public static final String LANGUAGE = "py";
diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java
new file mode 100644
index 0000000..e9a3843
--- /dev/null
+++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java
@@ -0,0 +1,208 @@
+/*
+ * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
+ * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.greencodeinitiative.creedengo.python.checks;
+
+
+import org.sonar.check.Priority;
+import org.sonar.check.Rule;
+import org.sonar.plugins.python.api.PythonSubscriptionCheck;
+import org.sonar.plugins.python.api.SubscriptionContext;
+import org.sonar.plugins.python.api.symbols.ClassSymbol;
+import org.sonar.plugins.python.api.tree.Name;
+import org.sonar.plugins.python.api.tree.ClassDef;
+import org.sonar.plugins.python.api.tree.CallExpression;
+import org.sonar.plugins.python.api.tree.Tree;
+import org.sonar.plugins.python.api.tree.Expression;
+import org.sonar.plugins.python.api.tree.RegularArgument;
+import org.sonar.plugins.python.api.tree.FunctionDef;
+import org.sonar.plugins.python.api.tree.Argument;
+import org.sonar.plugins.python.api.tree.AssignmentStatement;
+import org.sonar.plugins.python.api.tree.BaseTreeVisitor;
+import org.sonar.plugins.python.api.tree.QualifiedExpression;
+import org.sonar.plugins.python.api.tree.Statement;
+
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.Objects;
+import java.util.Optional;
+
+
+
+import static org.sonar.plugins.python.api.tree.Tree.Kind.ASSIGNMENT_STMT;
+import static org.sonar.plugins.python.api.tree.Tree.Kind.CALL_EXPR;
+import static org.sonar.plugins.python.api.tree.Tree.Kind.NAME;
+import static org.sonar.plugins.python.api.tree.Tree.Kind.REGULAR_ARGUMENT;
+import static org.sonar.plugins.python.api.tree.Tree.Kind.FUNCDEF;
+
+@Rule(key="GCI101")
+
+public class AvoidConvBiasBeforeBatchNorm extends PythonSubscriptionCheck {
+
+    public static final String RULE_KEY = "GCI101";
+ private static final String nnModuleFullyQualifiedName = "torch.nn.Module";
+ private static final String convFullyQualifiedName = "torch.nn.Conv2d";
+ private static final String forwardMethodName = "forward";
+ private static final String batchNormFullyQualifiedName = "torch.nn.BatchNorm2d";
+ private static final String sequentialModuleFullyQualifiedName = "torch.nn.Sequential";
+ protected static final String MESSAGE = "Remove bias for convolutions before batch norm layers to save time and memory.";
+
+ @Override
+ public void initialize(Context context) {
+ context.registerSyntaxNodeConsumer(Tree.Kind.CLASSDEF, ctx -> {
+ ClassDef classDef = (ClassDef) ctx.syntaxNode();
+ Optional.ofNullable(classDef).filter(this::isModelClass).ifPresent(e -> visitModelClass(ctx, e));
+ });
+ }
+
+ private boolean isConvWithBias(CallExpression convDefinition) {
+ RegularArgument biasArgument = Utils.nthArgumentOrKeyword(7, "bias", convDefinition.arguments());
+ if (biasArgument == null)
+ return true;
+ else {
+ Expression expression = biasArgument.expression();
+ return expression.is(NAME) && ((Name) expression).name().equals("True");
+ }
+ }
+
+ private boolean isModelClass(ClassDef classDef) {
+ ClassSymbol classSymbol = (ClassSymbol) classDef.name().symbol();
+ if (classSymbol != null) {
+ return classSymbol.superClasses().stream().anyMatch(e -> Objects.equals(e.fullyQualifiedName(), nnModuleFullyQualifiedName))
+ && classSymbol.declaredMembers().stream().anyMatch(e -> e.name().equals(forwardMethodName));
+ } else
+ return false;
+ }
+
+    private void reportIfBatchNormIsCalledAfterDirtyConv(SubscriptionContext context, FunctionDef forwardDef, Map<String, CallExpression> dirtyConvInInit,
+                                                         Map<String, CallExpression> batchNormsInInit) {
+ ForwardMethodVisitor visitor = new ForwardMethodVisitor();
+ forwardDef.accept(visitor);
+
+ for (CallExpression callInForward : visitor.callExpressions) {
+ // if it is a batchNorm
+ if (batchNormsInInit.containsKey(Utils.getMethodName(callInForward))) {
+ int batchNormLineNo = callInForward.firstToken().line();
+ for (Argument batchNormArgument : Utils.getArgumentsFromCall(callInForward)) {
+ Expression batchNormArgumentExpression = ((RegularArgument) batchNormArgument).expression();
+ if (batchNormArgumentExpression.is(CALL_EXPR)) {
+ String functionName = Utils.getMethodName((CallExpression) batchNormArgumentExpression);
+ if (dirtyConvInInit.containsKey(functionName)) {
+ context.addIssue(dirtyConvInInit.get(functionName), MESSAGE);
+ }
+
+ // if it uses a variable
+ } else if (batchNormArgumentExpression.is(NAME) && ((Name) batchNormArgumentExpression).isVariable()) {
+ String batchNormArgumentName = ((Name) batchNormArgumentExpression).name();
+
+ // loop through all call expressions in forward
+ AssignmentStatement lastAssignmentStatementBeforeBatchNorm = null;
+
+                        for (AssignmentStatement assignmentStatement : visitor.assignmentStatements) {
+                            if (assignmentStatement.firstToken().line() >= batchNormLineNo)
+                                break;
+                            // lhs may be qualified (e.g. self.x = ...); only a plain name can match the argument
+                            Expression lhsExpr = assignmentStatement.lhsExpressions().get(0).expressions().get(0);
+
+                            if (lhsExpr.is(NAME) && ((Name) lhsExpr).name().equals(batchNormArgumentName))
+                                lastAssignmentStatementBeforeBatchNorm = assignmentStatement;
+                        }
+ if (lastAssignmentStatementBeforeBatchNorm != null && lastAssignmentStatementBeforeBatchNorm.assignedValue().is(CALL_EXPR)) {
+ CallExpression function = (CallExpression) lastAssignmentStatementBeforeBatchNorm.assignedValue();
+ String functionName = Utils.getMethodName(function);
+ if (dirtyConvInInit.containsKey(functionName)) {
+ context.addIssue(dirtyConvInInit.get(functionName), MESSAGE);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ private void reportForSequentialModules(SubscriptionContext context, CallExpression sequentialCall) {
+ int moduleIndex = 0;
+ int nModulesInSequential = Utils.getArgumentsFromCall(sequentialCall).size();
+ while (moduleIndex < nModulesInSequential) {
+ Argument moduleInSequential = Utils.getArgumentsFromCall(sequentialCall).get(moduleIndex);
+ if (moduleInSequential.is(REGULAR_ARGUMENT) && ((RegularArgument) moduleInSequential).expression().is(CALL_EXPR)) {
+ CallExpression module = (CallExpression) ((RegularArgument) moduleInSequential).expression();
+ if (Utils.getQualifiedName(module).equals(convFullyQualifiedName) && isConvWithBias(module)) {
+ if (moduleIndex == nModulesInSequential - 1)
+ break;
+                Argument nextModuleInSequential = Utils.getArgumentsFromCall(sequentialCall).get(moduleIndex + 1);
+                if (nextModuleInSequential.is(REGULAR_ARGUMENT) && ((RegularArgument) nextModuleInSequential).expression().is(CALL_EXPR)
+                        && Utils.getQualifiedName((CallExpression) ((RegularArgument) nextModuleInSequential).expression()).equals(batchNormFullyQualifiedName))
+                    context.addIssue(module, MESSAGE);
+ }
+ }
+ moduleIndex += 1;
+ }
+ }
+
+ private void visitModelClass(SubscriptionContext context, ClassDef classDef) {
+        Map<String, CallExpression> dirtyConvInInit = new HashMap<>();
+        Map<String, CallExpression> batchNormsInInit = new HashMap<>();
+
+ for (Statement s : classDef.body().statements()) {
+ if (s.is(FUNCDEF) && ((FunctionDef) s).name().name().equals("__init__")) {
+ for (Statement ss : ((FunctionDef) s).body().statements()) {
+ if (ss.is(ASSIGNMENT_STMT)) {
+                    Expression lhs = ((AssignmentStatement) ss).lhsExpressions().get(0).expressions().get(0);
+                    // only attribute assignments of module calls (self.x = nn....); 'continue', not 'break', to keep scanning __init__
+                    if (!((AssignmentStatement) ss).assignedValue().is(CALL_EXPR) || !lhs.is(Tree.Kind.QUALIFIED_EXPR))
+                        continue;
+                    CallExpression callExpression = (CallExpression) ((AssignmentStatement) ss).assignedValue();
+                    String variableName = ((QualifiedExpression) lhs).name().name();
+ String variableClass = Utils.getQualifiedName(callExpression);
+ if (variableClass.equals(sequentialModuleFullyQualifiedName)) {
+ reportForSequentialModules(context, callExpression);
+                    } else if (variableClass.equals(convFullyQualifiedName) && isConvWithBias(callExpression)) {
+ dirtyConvInInit.put(variableName, callExpression);
+                    } else if (variableClass.equals(batchNormFullyQualifiedName)) {
+ batchNormsInInit.put(variableName, callExpression);
+ }
+ }
+ }
+ }
+ }
+ for (Statement s : classDef.body().statements()) {
+ if (s.is(FUNCDEF) && ((FunctionDef) s).name().name().equals(forwardMethodName)) {
+ FunctionDef forwardDef = (FunctionDef) s;
+ reportIfBatchNormIsCalledAfterDirtyConv(context, forwardDef, dirtyConvInInit, batchNormsInInit);
+ }
+ }
+
+ }
+
+ private static class ForwardMethodVisitor extends BaseTreeVisitor {
+        private final ArrayList<CallExpression> callExpressions = new ArrayList<>();
+        private final ArrayList<AssignmentStatement> assignmentStatements = new ArrayList<>();
+
+ @Override
+ public void visitCallExpression(CallExpression pyCallExpressionTree) {
+ callExpressions.add(pyCallExpressionTree);
+ super.visitCallExpression(pyCallExpressionTree);
+ }
+
+        @Override public void visitAssignmentStatement(AssignmentStatement pyAssignmentStatementTree) {
+ assignmentStatements.add(pyAssignmentStatementTree);
+ super.visitAssignmentStatement(pyAssignmentStatementTree);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java
new file mode 100644
index 0000000..44f2183
--- /dev/null
+++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java
@@ -0,0 +1,102 @@
+/*
+ * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
+ * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.greencodeinitiative.creedengo.python.checks;
+
+import org.sonar.plugins.python.api.SubscriptionContext;
+import org.sonar.plugins.python.api.symbols.Symbol;
+import org.sonar.plugins.python.api.tree.Argument;
+import org.sonar.plugins.python.api.tree.AssignmentStatement;
+import org.sonar.plugins.python.api.tree.Tree;
+import org.sonar.plugins.python.api.tree.RegularArgument;
+import org.sonar.plugins.python.api.tree.Name;
+import org.sonar.plugins.python.api.tree.CallExpression;
+import org.sonar.plugins.python.api.tree.Expression;
+
+import javax.annotation.CheckForNull;
+import java.util.List;
+import java.util.Objects;
+
+public final class Utils {
+
+ private static boolean hasKeyword(Argument argument, String keyword) {
+        if (!argument.is(Tree.Kind.REGULAR_ARGUMENT)) {
+ return false;
+ } else {
+ Name keywordArgument = ((RegularArgument) argument).keywordArgument();
+ return keywordArgument != null && keywordArgument.name().equals(keyword);
+ }
+ }
+
+ @CheckForNull
+    public static RegularArgument nthArgumentOrKeyword(int argPosition, String keyword, List<Argument> arguments) {
+        for (int i = 0; i < arguments.size(); ++i) {
+            Argument argument = arguments.get(i);
+ if (hasKeyword(argument, keyword)) {
+ return (RegularArgument) argument;
+ }
+
+            if (argument.is(Tree.Kind.REGULAR_ARGUMENT)) {
+ RegularArgument regularArgument = (RegularArgument) argument;
+ if (regularArgument.keywordArgument() == null && argPosition == i) {
+ return regularArgument;
+ }
+ }
+ }
+
+ return null;
+ }
+
+ public static String getQualifiedName(CallExpression callExpression) {
+ Symbol symbol = callExpression.calleeSymbol();
+
+ return symbol != null && symbol.fullyQualifiedName() != null ? symbol.fullyQualifiedName() : "";
+ }
+
+ public static String getMethodName(CallExpression callExpression) {
+ Symbol symbol = callExpression.calleeSymbol();
+ return symbol != null && symbol.name() != null ? symbol.name() : "";
+ }
+
+    public static List<Argument> getArgumentsFromCall(CallExpression callExpression) {
+        // explicit null check instead of catching NullPointerException as control flow
+        if (callExpression.argumentList() == null) {
+            return List.of();
+        }
+        return callExpression.argumentList().arguments();
+    }
+
+ public static String getVariableName(SubscriptionContext context) {
+ Tree node = context.syntaxNode();
+ Tree current = node;
+ while (current != null && !current.is(Tree.Kind.ASSIGNMENT_STMT)) {
+ current = current.parent();
+ }
+ if (current != null && current.is(Tree.Kind.ASSIGNMENT_STMT)) {
+ AssignmentStatement assignment = (AssignmentStatement) current;
+ if (!assignment.lhsExpressions().isEmpty() && !assignment.lhsExpressions().get(0).expressions().isEmpty()) {
+ Expression leftExpr = assignment.lhsExpressions().get(0).expressions().get(0);
+ if (leftExpr.is(Tree.Kind.NAME)) {
+ Name variableName = (Name) leftExpr;
+ return variableName.name();
+ }
+ }
+
+ }
+ return null;
+ }
+}
diff --git a/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java
new file mode 100644
index 0000000..39e9452
--- /dev/null
+++ b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java
@@ -0,0 +1,29 @@
+/*
+ * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
+ * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.greencodeinitiative.creedengo.python.checks;
+
+import org.junit.Test;
+import org.sonar.python.checks.utils.PythonCheckVerifier;
+
+public class AvoidConvBiasBeforeBatchNormTest {
+
+ @Test
+ public void test() {
+ PythonCheckVerifier.verify("src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py", new AvoidConvBiasBeforeBatchNorm());
+ }
+}
diff --git a/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py b/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py
new file mode 100644
index 0000000..2254e97
--- /dev/null
+++ b/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py
@@ -0,0 +1,238 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+#input = torch.rand(2, 1, 320, 111)
+
+
+class RandomClass(Object):
+ def __init__(self, a):
+ self.a = a
+ def forward(self, x):
+ return self.a + x
+
+class WeirdModelWithoutForward(nn.Module):
+ def __init__(self, a):
+ self.a = a
+
+class WeirdModelWithoutInit(nn.Module):
+
+ def forward(x):
+ return x
+
+class NetWithConvBiasSetToTrueWithARandomChange(nn.Module):
+ def __init__(self):
+ super(NetWithConvBiasSetToTrueWithARandomChange, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True)
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = self.conv2(x)
+ x = x / 2
+ x = self.conv2_bn(x)
+ x = F.relu(F.max_pool2d(x, 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+
+class NetWithConvBiasSetToTrueWithARandomAddedLineBetweenConvAndBN(nn.Module):
+ def __init__(self):
+ super(NetWithConvBiasSetToTrueWithARandomAddedLineBetweenConvAndBN, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1, bias=False)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ self.idx = 0
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = self.conv2(x)
+ self.idx += 1
+ x = self.conv2_bn(x)
+ x = F.relu(F.max_pool2d(x, 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+class NetWithConvBiasSetToTrueWithDiffVariableName(nn.Module):
+ def __init__(self):
+ super(NetWithConvBiasSetToTrueWithDiffVariableName, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1, bias=False)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ self.idx = 0
+ def forward(self, x):
+ x1 = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x2 = self.conv2(x1)
+ self.idx += 1
+ x3 = self.conv2_bn(x2)
+ x4 = F.relu(F.max_pool2d(x3, 2))
+ x5 = x4.view(-1, 320)
+ x6 = F.relu(self.dense1_bn(self.dense1(x5)))
+ return F.relu(self.dense2(x6))
+
+class CompNetWithConvBiasSetToTrueWithDiffVariableName(nn.Module):
+ def __init__(self):
+ super(CompNetWithConvBiasSetToTrueWithDiffVariableName, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1, bias=False)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=False)
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ self.idx = 0
+ def forward(self, x):
+ x1 = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x2 = self.conv2(x1)
+ self.idx += 1
+ x3 = self.conv2_bn(x2)
+ x4 = F.relu(F.max_pool2d(x3, 2))
+ x5 = x4.view(-1, 320)
+ x6 = F.relu(self.dense1_bn(self.dense1(x5)))
+ return F.relu(self.dense2(x6))
+
+class NetWithConvBiasSetToTrue(nn.Module):
+ def __init__(self):
+ super(NetWithConvBiasSetToTrue, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = self.conv2(x)
+ x = self.conv2_bn(x)
+ x = F.relu(F.max_pool2d(x, 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+class NetWithDefaultConvBias(nn.Module):
+ def __init__(self):
+ super(NetWithDefaultConvBias, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+class NonCompliantNetWithSequentialKeywordParam(nn.Module):
+ def __init__(self):
+ super(NonCompliantNetWithSequentialKeywordParam, self).__init__()
+ self.encoder = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1),
+ nn.MaxPool2d(2),
+ nn.ReLU(),
+ nn.Conv2d(10, 20, kernel_size=5, bias=True), # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ nn.BatchNorm2d(20),
+ nn.MaxPool2d(2),
+ nn.ReLU()
+ )
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ def forward(self, x):
+ x = self.encoder(x)
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+
+class NonCompliantNetWithSequentialPosParam(nn.Module):
+ def __init__(self):
+ super(NonCompliantNetWithSequentialPosParam, self).__init__()
+ self.encoder = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, bias=False),
+ nn.MaxPool2d(2),
+ nn.ReLU(),
+ nn.Conv2d(10, 20, kernel_size=5), # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}}
+ nn.BatchNorm2d(20),
+ nn.MaxPool2d(2),
+ nn.ReLU()
+ )
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ def forward(self, x):
+ x = self.encoder(x)
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+
+
+class CompliantNet(nn.Module):
+ def __init__(self):
+ super(CompliantNet, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
+ kernel_size=5,
+ stride=1)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=False)
+ self.conv2_bn = nn.BatchNorm2d(20)
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+
+class CompliantNetWithSequential(nn.Module):
+ def __init__(self):
+ super(CompliantNetWithSequential, self).__init__()
+ self.encoder = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1),
+ nn.MaxPool2d(2),
+ nn.ReLU(),
+ nn.Conv2d(10, 20, kernel_size=5, bias=False),
+ nn.BatchNorm2d(20),
+ nn.MaxPool2d(2),
+ nn.ReLU()
+ )
+ self.dense1 = nn.Linear(in_features=320, out_features=50)
+ self.dense1_bn = nn.BatchNorm1d(50)
+ self.dense2 = nn.Linear(50, 10)
+ def forward(self, x):
+ x = self.encoder(x)
+ x = x.view(-1, 320)
+ x = F.relu(self.dense1_bn(self.dense1(x)))
+ return F.relu(self.dense2(x))
+
+
+