diff --git a/CHANGELOG.md b/CHANGELOG.md index f97f08d..b868024 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add rule GCI101 Avoid Conv Bias Before Batch Normalization, a rule specific to Deeplearning +- Add an utils file to keep all utils functions in one place + ### Changed - compatibility updates for SonarQube 25.5.0 diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java index c385979..e3c7520 100644 --- a/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java +++ b/src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java @@ -40,7 +40,8 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe AvoidFullSQLRequest.class, AvoidListComprehensionInIterations.class, DetectUnoptimizedImageFormat.class, - AvoidMultipleIfElseStatementCheck.class + AvoidMultipleIfElseStatementCheck.class, + AvoidConvBiasBeforeBatchNorm.class ); public static final String LANGUAGE = "py"; diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java new file mode 100644 index 0000000..e9a3843 --- /dev/null +++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNorm.java @@ -0,0 +1,208 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your 
option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.greencodeinitiative.creedengo.python.checks; + + +import org.sonar.check.Priority; +import org.sonar.check.Rule; +import org.sonar.plugins.python.api.PythonSubscriptionCheck; +import org.sonar.plugins.python.api.SubscriptionContext; +import org.sonar.plugins.python.api.symbols.ClassSymbol; +import org.sonar.plugins.python.api.tree.Name; +import org.sonar.plugins.python.api.tree.ClassDef; +import org.sonar.plugins.python.api.tree.CallExpression; +import org.sonar.plugins.python.api.tree.Tree; +import org.sonar.plugins.python.api.tree.Expression; +import org.sonar.plugins.python.api.tree.RegularArgument; +import org.sonar.plugins.python.api.tree.FunctionDef; +import org.sonar.plugins.python.api.tree.Argument; +import org.sonar.plugins.python.api.tree.AssignmentStatement; +import org.sonar.plugins.python.api.tree.BaseTreeVisitor; +import org.sonar.plugins.python.api.tree.QualifiedExpression; +import org.sonar.plugins.python.api.tree.Statement; + +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.Objects; +import java.util.Optional; + + + +import static org.sonar.plugins.python.api.tree.Tree.Kind.ASSIGNMENT_STMT; +import static org.sonar.plugins.python.api.tree.Tree.Kind.CALL_EXPR; +import static org.sonar.plugins.python.api.tree.Tree.Kind.NAME; +import static org.sonar.plugins.python.api.tree.Tree.Kind.REGULAR_ARGUMENT; +import static org.sonar.plugins.python.api.tree.Tree.Kind.FUNCDEF; + +@Rule(key="GCI101") + +public class AvoidConvBiasBeforeBatchNorm extends PythonSubscriptionCheck { + + public static final 
String RULE_KEY = "GCI101"; + private static final String nnModuleFullyQualifiedName = "torch.nn.Module"; + private static final String convFullyQualifiedName = "torch.nn.Conv2d"; + private static final String forwardMethodName = "forward"; + private static final String batchNormFullyQualifiedName = "torch.nn.BatchNorm2d"; + private static final String sequentialModuleFullyQualifiedName = "torch.nn.Sequential"; + protected static final String MESSAGE = "Remove bias for convolutions before batch norm layers to save time and memory."; + + @Override + public void initialize(Context context) { + context.registerSyntaxNodeConsumer(Tree.Kind.CLASSDEF, ctx -> { + ClassDef classDef = (ClassDef) ctx.syntaxNode(); + Optional.ofNullable(classDef).filter(this::isModelClass).ifPresent(e -> visitModelClass(ctx, e)); + }); + } + + private boolean isConvWithBias(CallExpression convDefinition) { + RegularArgument biasArgument = Utils.nthArgumentOrKeyword(7, "bias", convDefinition.arguments()); + if (biasArgument == null) + return true; + else { + Expression expression = biasArgument.expression(); + return expression.is(NAME) && ((Name) expression).name().equals("True"); + } + } + + private boolean isModelClass(ClassDef classDef) { + ClassSymbol classSymbol = (ClassSymbol) classDef.name().symbol(); + if (classSymbol != null) { + return classSymbol.superClasses().stream().anyMatch(e -> Objects.equals(e.fullyQualifiedName(), nnModuleFullyQualifiedName)) + && classSymbol.declaredMembers().stream().anyMatch(e -> e.name().equals(forwardMethodName)); + } else + return false; + } + + private void reportIfBatchNormIsCalledAfterDirtyConv(SubscriptionContext context, FunctionDef forwardDef, Map<String, CallExpression> dirtyConvInInit, + Map<String, CallExpression> batchNormsInInit) { + ForwardMethodVisitor visitor = new ForwardMethodVisitor(); + forwardDef.accept(visitor); + + for (CallExpression callInForward : visitor.callExpressions) { + // if it is a batchNorm + if (batchNormsInInit.containsKey(Utils.getMethodName(callInForward))) { + int
batchNormLineNo = callInForward.firstToken().line(); + for (Argument batchNormArgument : Utils.getArgumentsFromCall(callInForward)) { + Expression batchNormArgumentExpression = batchNormArgument.is(REGULAR_ARGUMENT) ? ((RegularArgument) batchNormArgument).expression() : null; + if (batchNormArgumentExpression != null && batchNormArgumentExpression.is(CALL_EXPR)) { + String functionName = Utils.getMethodName((CallExpression) batchNormArgumentExpression); + if (dirtyConvInInit.containsKey(functionName)) { + context.addIssue(dirtyConvInInit.get(functionName), MESSAGE); + } + + // if it uses a variable + } else if (batchNormArgumentExpression != null && batchNormArgumentExpression.is(NAME) && ((Name) batchNormArgumentExpression).isVariable()) { + String batchNormArgumentName = ((Name) batchNormArgumentExpression).name(); + + // find the last assignment to that variable before the batch norm call + AssignmentStatement lastAssignmentStatementBeforeBatchNorm = null; + + for (AssignmentStatement assignmentStatement : visitor.assignmentStatements) { + Expression assignTarget = assignmentStatement.lhsExpressions().get(0).expressions().get(0); + String variableName = assignTarget.is(NAME) ? ((Name) assignTarget).name() : ""; + if (assignmentStatement.firstToken().line() >= batchNormLineNo) + break; + + if (variableName.equals(batchNormArgumentName)) + lastAssignmentStatementBeforeBatchNorm = assignmentStatement; + } + if (lastAssignmentStatementBeforeBatchNorm != null && lastAssignmentStatementBeforeBatchNorm.assignedValue().is(CALL_EXPR)) { + CallExpression function = (CallExpression) lastAssignmentStatementBeforeBatchNorm.assignedValue(); + String functionName = Utils.getMethodName(function); + if (dirtyConvInInit.containsKey(functionName)) { + context.addIssue(dirtyConvInInit.get(functionName), MESSAGE); + } + } + } + } + } + } + + private void reportForSequentialModules(SubscriptionContext context, CallExpression sequentialCall) { + int moduleIndex = 0; + int nModulesInSequential = Utils.getArgumentsFromCall(sequentialCall).size(); + while (moduleIndex < nModulesInSequential) { + Argument moduleInSequential =
Utils.getArgumentsFromCall(sequentialCall).get(moduleIndex); + if (moduleInSequential.is(REGULAR_ARGUMENT) && ((RegularArgument) moduleInSequential).expression().is(CALL_EXPR)) { + CallExpression module = (CallExpression) ((RegularArgument) moduleInSequential).expression(); + if (Utils.getQualifiedName(module).equals(convFullyQualifiedName) && isConvWithBias(module)) { + if (moduleIndex == nModulesInSequential - 1) + break; + Argument nextModuleInSequential = Utils.getArgumentsFromCall(sequentialCall).get(moduleIndex + 1); + Expression nextExpression = nextModuleInSequential.is(REGULAR_ARGUMENT) ? ((RegularArgument) nextModuleInSequential).expression() : null; + if (nextExpression != null && nextExpression.is(CALL_EXPR) && Utils.getQualifiedName((CallExpression) nextExpression).equals(batchNormFullyQualifiedName)) + context.addIssue(module, MESSAGE); + } + } + moduleIndex += 1; + } + } + + private void visitModelClass(SubscriptionContext context, ClassDef classDef) { + Map<String, CallExpression> dirtyConvInInit = new HashMap<>(); + Map<String, CallExpression> batchNormsInInit = new HashMap<>(); + + for (Statement s : classDef.body().statements()) { + if (s.is(FUNCDEF) && ((FunctionDef) s).name().name().equals("__init__")) { + for (Statement ss : ((FunctionDef) s).body().statements()) { + if (ss.is(ASSIGNMENT_STMT)) { + Expression lhs = ((AssignmentStatement) ss).lhsExpressions().get(0).expressions().get(0); + // consider only attribute assignments whose value is a call (module construction) + if (!((AssignmentStatement) ss).assignedValue().is(CALL_EXPR) || !lhs.is(Tree.Kind.QUALIFIED_EXPR)) + continue; + CallExpression callExpression = (CallExpression) ((AssignmentStatement) ss).assignedValue(); + String variableName = ((QualifiedExpression) lhs).name().name(); + String variableClass = Utils.getQualifiedName(callExpression); + if (variableClass.equals(sequentialModuleFullyQualifiedName)) { + reportForSequentialModules(context, callExpression); + } else if (!variableClass.isEmpty() && convFullyQualifiedName.contains(variableClass) && isConvWithBias(callExpression)) { + dirtyConvInInit.put(variableName, callExpression); + } else if (!variableClass.isEmpty() && batchNormFullyQualifiedName.contains(variableClass)) { + batchNormsInInit.put(variableName, callExpression); +
} + } + } + } + } + for (Statement s : classDef.body().statements()) { + if (s.is(FUNCDEF) && ((FunctionDef) s).name().name().equals(forwardMethodName)) { + FunctionDef forwardDef = (FunctionDef) s; + reportIfBatchNormIsCalledAfterDirtyConv(context, forwardDef, dirtyConvInInit, batchNormsInInit); + } + } + + } + + private static class ForwardMethodVisitor extends BaseTreeVisitor { + private final ArrayList<CallExpression> callExpressions = new ArrayList<>(); + private final ArrayList<AssignmentStatement> assignmentStatements = new ArrayList<>(); + + @Override + public void visitCallExpression(CallExpression pyCallExpressionTree) { + callExpressions.add(pyCallExpressionTree); + super.visitCallExpression(pyCallExpressionTree); + } + + @Override public void visitAssignmentStatement(AssignmentStatement pyAssignmentStatementTree) { + assignmentStatements.add(pyAssignmentStatementTree); + super.visitAssignmentStatement(pyAssignmentStatementTree); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java new file mode 100644 index 0000000..44f2183 --- /dev/null +++ b/src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java @@ -0,0 +1,102 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details.
+ * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ +package org.greencodeinitiative.creedengo.python.checks; + +import org.sonar.plugins.python.api.SubscriptionContext; +import org.sonar.plugins.python.api.symbols.Symbol; +import org.sonar.plugins.python.api.tree.Argument; +import org.sonar.plugins.python.api.tree.AssignmentStatement; +import org.sonar.plugins.python.api.tree.Tree; +import org.sonar.plugins.python.api.tree.RegularArgument; +import org.sonar.plugins.python.api.tree.Name; +import org.sonar.plugins.python.api.tree.CallExpression; +import org.sonar.plugins.python.api.tree.Expression; + +import javax.annotation.CheckForNull; +import java.util.List; +import java.util.Objects; + +public class Utils { private Utils() {} // utility class: prevent instantiation + + private static boolean hasKeyword(Argument argument, String keyword) { + if (!argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) { + return false; + } else { + Name keywordArgument = ((RegularArgument) argument).keywordArgument(); + return keywordArgument != null && keywordArgument.name().equals(keyword); + } + } + + @CheckForNull + public static RegularArgument nthArgumentOrKeyword(int argPosition, String keyword, List<Argument> arguments) { + for (int i = 0; i < arguments.size(); ++i) { + Argument argument = arguments.get(i); + if (hasKeyword(argument, keyword)) { + return (RegularArgument) argument; + } + + if (argument.is(new Tree.Kind[] {Tree.Kind.REGULAR_ARGUMENT})) { + RegularArgument regularArgument = (RegularArgument) argument; + if (regularArgument.keywordArgument() == null && argPosition == i) { + return regularArgument; + } + } + } + + return null; + } + + public static String getQualifiedName(CallExpression callExpression) { + Symbol symbol = callExpression.calleeSymbol(); + + return symbol != null && symbol.fullyQualifiedName() != null ?
symbol.fullyQualifiedName() : ""; + } + + public static String getMethodName(CallExpression callExpression) { + Symbol symbol = callExpression.calleeSymbol(); + return symbol != null && symbol.name() != null ? symbol.name() : ""; + } + + public static List<Argument> getArgumentsFromCall(CallExpression callExpression) { + // a call may legitimately have no argument list; do not use NPE as control flow + if (callExpression.argumentList() == null) { + return List.of(); + } + return callExpression.argumentList().arguments(); + } + + public static String getVariableName(SubscriptionContext context) { + Tree node = context.syntaxNode(); + Tree current = node; + while (current != null && !current.is(Tree.Kind.ASSIGNMENT_STMT)) { + current = current.parent(); + } + if (current != null && current.is(Tree.Kind.ASSIGNMENT_STMT)) { + AssignmentStatement assignment = (AssignmentStatement) current; + if (!assignment.lhsExpressions().isEmpty() && !assignment.lhsExpressions().get(0).expressions().isEmpty()) { + Expression leftExpr = assignment.lhsExpressions().get(0).expressions().get(0); + if (leftExpr.is(Tree.Kind.NAME)) { + Name variableName = (Name) leftExpr; + return variableName.name(); + } + } + + } + return null; + } +} diff --git a/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java new file mode 100644 index 0000000..39e9452 --- /dev/null +++ b/src/test/java/org/greencodeinitiative/creedengo/python/checks/AvoidConvBiasBeforeBatchNormTest.java @@ -0,0 +1,29 @@ +/* + * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs + * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org) + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any
later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.greencodeinitiative.creedengo.python.checks; + +import org.junit.Test; +import org.sonar.python.checks.utils.PythonCheckVerifier; + +public class AvoidConvBiasBeforeBatchNormTest { + + @Test + public void test() { + PythonCheckVerifier.verify("src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py", new AvoidConvBiasBeforeBatchNorm()); + } +} diff --git a/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py b/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py new file mode 100644 index 0000000..2254e97 --- /dev/null +++ b/src/test/resources/checks/avoidConvBiasBeforeBatchNorm.py @@ -0,0 +1,238 @@ +import torch.nn as nn +import torch.nn.functional as F + +#input = torch.rand(2, 1, 320, 111) + + +class RandomClass(Object): + def __init__(self, a): + self.a = a + def forward(self, x): + return self.a + x + +class WeirdModelWithoutForward(nn.Module): + def __init__(self, a): + self.a = a + +class WeirdModelWithoutInit(nn.Module): + + def forward(x): + return x + +class NetWithConvBiasSetToTrueWithARandomChange(nn.Module): + def __init__(self): + super(NetWithConvBiasSetToTrueWithARandomChange, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.conv2(x) + x = x / 2 + x = self.conv2_bn(x) + x = 
F.relu(F.max_pool2d(x, 2)) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + + +class NetWithConvBiasSetToTrueWithARandomAddedLineBetweenConvAndBN(nn.Module): + def __init__(self): + super(NetWithConvBiasSetToTrueWithARandomAddedLineBetweenConvAndBN, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1, bias=False) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + self.idx = 0 + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.conv2(x) + self.idx += 1 + x = self.conv2_bn(x) + x = F.relu(F.max_pool2d(x, 2)) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + +class NetWithConvBiasSetToTrueWithDiffVariableName(nn.Module): + def __init__(self): + super(NetWithConvBiasSetToTrueWithDiffVariableName, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1, bias=False) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + self.idx = 0 + def forward(self, x): + x1 = F.relu(F.max_pool2d(self.conv1(x), 2)) + x2 = self.conv2(x1) + self.idx += 1 + x3 = self.conv2_bn(x2) + x4 = F.relu(F.max_pool2d(x3, 2)) + x5 = x4.view(-1, 320) + x6 = F.relu(self.dense1_bn(self.dense1(x5))) + return F.relu(self.dense2(x6)) + +class CompNetWithConvBiasSetToTrueWithDiffVariableName(nn.Module): + def __init__(self): + 
super(CompNetWithConvBiasSetToTrueWithDiffVariableName, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1, bias=False) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=False) + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + self.idx = 0 + def forward(self, x): + x1 = F.relu(F.max_pool2d(self.conv1(x), 2)) + x2 = self.conv2(x1) + self.idx += 1 + x3 = self.conv2_bn(x2) + x4 = F.relu(F.max_pool2d(x3, 2)) + x5 = x4.view(-1, 320) + x6 = F.relu(self.dense1_bn(self.dense1(x5))) + return F.relu(self.dense2(x6)) + +class NetWithConvBiasSetToTrue(nn.Module): + def __init__(self): + super(NetWithConvBiasSetToTrue, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.conv2(x) + x = self.conv2_bn(x) + x = F.relu(F.max_pool2d(x, 2)) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + +class NetWithDefaultConvBias(nn.Module): + def __init__(self): + super(NetWithDefaultConvBias, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=True) # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + + def 
forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + +class NonCompliantNetWithSequentialKeywordParam(nn.Module): + def __init__(self): + super(NonCompliantNetWithSequentialKeywordParam, self).__init__() + self.encoder = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Conv2d(10, 20, kernel_size=5, bias=True), # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + nn.BatchNorm2d(20), + nn.MaxPool2d(2), + nn.ReLU() + ) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + def forward(self, x): + x = self.encoder(x) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + + +class NonCompliantNetWithSequentialPosParam(nn.Module): + def __init__(self): + super(NonCompliantNetWithSequentialPosParam, self).__init__() + self.encoder = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, bias=False), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Conv2d(10, 20, kernel_size=5), # Noncompliant {{Remove bias for convolutions before batch norm layers to save time and memory.}} + nn.BatchNorm2d(20), + nn.MaxPool2d(2), + nn.ReLU() + ) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + def forward(self, x): + x = self.encoder(x) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + + + +class CompliantNet(nn.Module): + def __init__(self): + super(CompliantNet, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, + kernel_size=5, + stride=1) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=False) + 
self.conv2_bn = nn.BatchNorm2d(20) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + + +class CompliantNetWithSequential(nn.Module): + def __init__(self): + super(CompliantNetWithSequential, self).__init__() + self.encoder = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Conv2d(10, 20, kernel_size=5, bias=False), + nn.BatchNorm2d(20), + nn.MaxPool2d(2), + nn.ReLU() + ) + self.dense1 = nn.Linear(in_features=320, out_features=50) + self.dense1_bn = nn.BatchNorm1d(50) + self.dense2 = nn.Linear(50, 10) + def forward(self, x): + x = self.encoder(x) + x = x.view(-1, 320) + x = F.relu(self.dense1_bn(self.dense1(x))) + return F.relu(self.dense2(x)) + + +