diff --git a/setup.cfg b/setup.cfg
index 505e4e33..0c64e05d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -54,6 +54,7 @@ install_requires =
     onnxruntime>=1.16.1
     sigtools>=4.0.1
     toposort>=1.7.0
+    onnxscript==0.2.6
 
 
 [options.packages.find]
diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py
index c89d2bdc..7c9f89a9 100644
--- a/src/qonnx/transformation/batchnorm_to_affine.py
+++ b/src/qonnx/transformation/batchnorm_to_affine.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 Xilinx, Inc.
+# Copyright (c) 2025 Advanced Micro Devices, Inc.
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -11,7 +11,7 @@
 #   this list of conditions and the following disclaimer in the documentation
 #   and/or other materials provided with the distribution.
 #
-# * Neither the name of Xilinx nor the names of its
+# * Neither the name of AMD nor the names of its
 #   contributors may be used to endorse or promote products derived from
 #   this software without specific prior written permission.
 #
@@ -27,78 +27,68 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
+from onnxscript import ir
+from onnxscript.rewriter import pattern, rewrite
 
+from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.base import Transformation
+from qonnx.transformation.fold_constants import FoldConstants
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.util.basic import get_by_name
+from qonnx.util.onnxscript import ReplacePattern
+
+
+def target_pattern(op, x, scale, bias, mean, var):
+    """Pattern to match: a BatchNormalization node with its five inputs."""
+    return op.BatchNormalization(x, scale, bias, mean, var)
+
+
+def replace_pattern(op, x, scale, bias, mean, var, **kwargs):
+    """Replacement for a matched BatchNormalization: x * A + B.
+
+    A = scale / sqrt(var + epsilon) and B = bias - A * mean are emitted as graph
+    nodes. The matched pattern is read from the 'match' keyword argument."""
+    # epsilon is an attribute of the matched node, not one of the pattern inputs
+    batch_norm = kwargs["match"].nodes[0]
+    epsilon_attr = batch_norm.attributes.get("epsilon", None)
+    epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value
+    epsilon_tensor = oh.make_tensor("epsilon", TensorProto.FLOAT, (1,), [epsilon_value])
+    epsilon = op.Constant(value=epsilon_tensor)
+
+    A = op.Div(scale, op.Sqrt(op.Add(var, epsilon)))
+    B = op.Sub(bias, op.Mul(A, mean))
+
+    # give A and B the shape (1, C, 1, ..., 1) so they broadcast over the channel
+    # axis of x; Reshape is used instead of Unsqueeze because Unsqueeze only has
+    # an 'axes' attribute before opset 13, while Reshape takes its shape as an input
+    input_shape = x.shape
+    assert input_shape is not None and len(input_shape) >= 2, "Unexpected number of dims found in BatchNormToAffine"
+    n_spatial_dims = len(input_shape) - 2
+    target_shape = [1, -1] + [1] * n_spatial_dims
+    shape_tensor = oh.make_tensor("target_shape", TensorProto.INT64, (len(target_shape),), target_shape)
+    shape_const = op.Constant(value=shape_tensor)
+    A = op.Reshape(A, shape_const)
+    B = op.Reshape(B, shape_const)
+
+    return op.Add(op.Mul(x, A), B)
+
+
+rule1 = pattern.RewriteRule(target_pattern, ReplacePattern(replace_pattern))
+rewrite_rules = pattern.RewriteRuleSet([rule1])
 
 
 class BatchNormToAffine(Transformation):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
 
     def apply(self, model):
-        graph = model.graph
-        node_ind = 0
-        graph_modified = False
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "BatchNormalization":
-                graph_modified = True
-                bn_input = n.input[0]
-                bn_output = n.output[0]
-                # extract batchnorm parameters as numpy arrays
-                scale = model.get_initializer(n.input[1])
-                bias = model.get_initializer(n.input[2])
-                mean = model.get_initializer(n.input[3])
-                variance = model.get_initializer(n.input[4])
-                epsilon = get_by_name(n.attribute, "epsilon")
-                epsilon = getattr(epsilon, "f", 1e-5)
-                # find A and B to compute batchnorm as affine transpose Ax+B
-                # TODO is a division by moving avg factor needed for variance?
-                A = scale / np.sqrt(epsilon + variance)
-                B = bias - (A * mean)
-                # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-                producer = model.find_producer(bn_input)
-                if producer is not None:
-                    if producer.op_type == "Unsqueeze":
-                        bn_input = producer.input[0]
-                consumer = model.find_consumer(bn_output)
-                if consumer is not None:
-                    if consumer.op_type == "Squeeze":
-                        bn_output = consumer.output[0]
-                data_shape = model.get_tensor_shape(bn_input)
-                assert A.ndim == B.ndim, "Unexpected mul/add dims in BatchNormToAffine"
-                assert len(data_shape) >= A.ndim, "Unexpected number of dims found in BatchNormToAffine"
-                # reshape the mul/add constants to match the data shape/dims
-                # by adding (1,) dimensions to the right
-                n_spatial_dims = len(data_shape) - 2
-                target_shape = (1, -1) + tuple(1 for i in range(n_spatial_dims))
-                A = A.reshape(target_shape)
-                B = B.reshape(target_shape)
-                # create value_info and initializers for Mul and Add constants
-                mul_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape)
-                graph.value_info.append(mul_const)
-                model.set_initializer(mul_const.name, A)
-                mul_output = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape)
-                graph.value_info.append(mul_output)
-                add_const = oh.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape)
-                graph.value_info.append(add_const)
-                model.set_initializer(add_const.name, B)
-                # create Mul and Add nodes to replace the batchnorm
-                mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name])
-                add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output])
-                # insert where the batchnorm is to preserve topological ordering
-                graph.node.insert(node_ind, mul_node)
-                graph.node.insert(node_ind + 1, add_node)
-                # remove old nodes
-                graph.node.remove(n)
-                if consumer is not None:
-                    if consumer.op_type == "Squeeze":
-                        graph.node.remove(consumer)
-                if producer is not None:
-                    if producer.op_type == "Unsqueeze":
-                        graph.node.remove(producer)
+        # the rewrite rules operate on the onnxscript IR, so convert there and back
+        model = ir.from_proto(model.model)
+        model = rewrite(model, pattern_rewrite_rules=rewrite_rules)
+        model = ir.to_proto(model)
+        model = ModelWrapper(model)
         model = model.transform(InferShapes())
-        return (model, graph_modified)
+        model = model.transform(FoldConstants())
+        # graph_modified is reported as False so the transformation is not re-applied
+        return (model, False)
diff --git a/src/qonnx/util/onnxscript.py b/src/qonnx/util/onnxscript.py
new file mode 100644
index 00000000..34dc6188
--- /dev/null
+++ b/src/qonnx/util/onnxscript.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025 Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of AMD nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import Sequence
+
+from onnxscript.ir import _tape
+from onnxscript.rewriter._rewrite_rule import ReplacementPatternFunction, ReplacementSubgraph
+
+# NOTE: _tape and _rewrite_rule are private onnxscript modules; keep this file in
+# sync with the onnxscript version pinned in setup.cfg
+RewriterContext = _tape.Builder
+
+
+class ReplacePattern(ReplacementPatternFunction):
+    """Utility wrapper that provides matched pattern information to the replacement function.
+    The matched pattern is passed as the 'match' keyword argument."""
+
+    def get_replacement(self, match):
+        """Call the wrapped function with a fresh builder, the match and its bindings.
+
+        Returns None if the wrapped function returns None, otherwise a
+        ReplacementSubgraph built from its outputs and the builder's recorded state."""
+        context = RewriterContext()
+        new_outputs = self._function(context, match=match, **match.bindings)
+        if new_outputs is None:
+            return None
+        if not isinstance(new_outputs, Sequence):
+            new_outputs = [new_outputs]
+        return ReplacementSubgraph(
+            match, new_outputs, context.nodes, context.initializers, context.used_opsets
+        )
diff --git a/tests/transformation/test_batchnorm_to_affine.py b/tests/transformation/test_batchnorm_to_affine.py
index 705a31c1..d97d83e8 100644
--- a/tests/transformation/test_batchnorm_to_affine.py
+++ b/tests/transformation/test_batchnorm_to_affine.py
@@ -117,3 +117,7 @@ def test_batchnorm_to_affine_epsilon(epsilon):
     output_lowered = output_dict[output_node_name]
 
     assert (output_original == output_lowered).all()
+
+    # the BatchNormalization node itself must be gone after lowering
+    op_types = [node.op_type for node in model_lowered.graph.node]
+    assert "BatchNormalization" not in op_types