diff --git a/docs/qonnx-custom-ops/trunc_op.md b/docs/qonnx-custom-ops/trunc_op.md index 1b5f0d04..d716c6c2 100644 --- a/docs/qonnx-custom-ops/trunc_op.md +++ b/docs/qonnx-custom-ops/trunc_op.md @@ -6,13 +6,18 @@ The attribute rounding_mode defines how truncated values are rounded. #### Version -This operator is not part of the ONNX standard and is not currently versioned. +This operator is not part of the ONNX standard. +The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2. #### Attributes
rounding_mode : string (default is "FLOOR")
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
signed : int (default is 1)
+
Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+
narrow : int (default is 0)
+
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
#### Inputs @@ -21,11 +26,13 @@ This operator is not part of the ONNX standard and is not currently versioned.
X (differentiable) : tensor(float32)
input tensor to truncate
scale : float32
-
The scale factor
+
The scale factor at the input of the truncation
zeropt : float32
-
The zero-point
+
The zero-point at the input of the truncation
in_bitwidth : int32
The number of bits used at the input of the truncation
+
out_scale : float32
+
The scale factor of the output of the truncation
out_bitwidth : int32
The number of bits used at the output of the truncation
@@ -91,26 +98,32 @@ from __future__ import unicode_literals import numpy as np -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): - # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Scaling y = inp_tensor / scale y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0 ** trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * output_scale return y diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..36e60b84 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,10 +31,10 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.general.quant import resolve_rounding_mode +from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode -def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): +def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR # Scaling @@ -42,27 +42,35 @@ def 
trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding y = y + zeropt # Rounding y = np.round(y) - # Truncate - trunc_bit_width = input_bit_width - output_bit_width - trunc_scale = 2.0**trunc_bit_width + # Rescale + trunc_scale = 2 ** np.round( + np.log2(output_scale / scale) + ) # Trunc scale should be a power-of-two - ensure that is the case y = y / trunc_scale - # To int + # Clamping + min_int_val = min_int(signed, narrow, output_bit_width) + max_int_val = max_int(signed, narrow, output_bit_width) + y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y) + y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y) + # To int (truncate) rounding_fx = resolve_rounding_mode(rounding_mode) y = rounding_fx(y) # Rescale - y = y - zeropt - y = y * scale + output_zeropt = zeropt / trunc_scale # Rescale zero-point + y = y - output_zeropt + y = y * output_scale return y class Trunc(CustomOp): - """Generic truncation operation for QONNX. Takes four inputs: - - input tensor to truncate - - the scale - - the zero-point + """Generic truncation operation for QONNX. 
Takes five inputs: + - input tensor to truncate + - the scale + - the zero-point + - the truncation scale + - the truncation bit-width The output is a tensor of the same shape as the input tensor, with truncated @@ -73,6 +81,8 @@ def get_nodeattr_types(self): return { # The rounding mode, which is used for the trunc function "rounding_mode": ("s", True, "FLOOR"), + "narrow": ("i", False, 0, {0, 1}), + "signed": ("i", False, 1, {0, 1}), } def make_shape_compatible_op(self, model): @@ -90,11 +100,16 @@ def execute_node(self, context, graph): scale = context[node.input[1]] zeropt = context[node.input[2]] input_bit_width = context[node.input[3]] - output_bit_width = context[node.input[4]] + output_scale = context[node.input[4]] + output_bit_width = context[node.input[5]] # save attributes rounding_mode = self.get_nodeattr("rounding_mode") + narrow = self.get_nodeattr("narrow") + signed = self.get_nodeattr("signed") # calculate output - ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode) + ret = trunc( + inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode + ) # set context according to output name context[node.output[0]] = ret