
Commit 61a2097

Set bias as optional for convolution folding. Needed for CLIP (#1581)

* Set bias as optional for convolution folding. Needed for CLIP
* Quality fixes
* Merge

1 parent 22e63cd · commit 61a2097

File tree: 1 file changed (+68 −68 lines changed)

src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py (+68 −68)
@@ -743,10 +743,10 @@ def _add_quantized_conv_matmul_add_ops(
     weight_quantize_node: NodeProto,
     input_quantize_params: QuantizationParams,
     weight_quantize_params: QuantizationParams,
-    bias_initializer: onnx.TensorProto,
-    bias_add_name: str,
     target_output: str,
     transpose_weight: bool,
+    bias_add_name: str,
+    bias_initializer: Optional[onnx.TensorProto] = None,
     output_quantize_node: Optional[NodeProto] = None,
     output_dequantize_node: Optional[NodeProto] = None,
 ):
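Note on the reordering: because `bias_initializer` now has a `None` default, it must follow all required parameters, which is why it moves below `transpose_weight` alongside `bias_add_name`. Since every call site in this file passes these by keyword, the reorder is behavior-neutral. A minimal sketch of the two call shapes the new signature allows (leading required arguments elided with `...` placeholders, values hypothetical):

```python
# Sketch only: "..." stands in for the leading required arguments.

# Conv/Gemm with a folded bias -- pass the initializer through:
_add_quantized_conv_matmul_add_ops(
    ...,
    target_output=conv_node.output[0],
    transpose_weight=False,
    bias_add_name="{}_bias_add".format(conv_node.name),
    bias_initializer=bias_initializer,
)

# Bias-free Conv (the CLIP case) -- simply omit bias_initializer,
# and the quantized-bias Add node is skipped entirely:
_add_quantized_conv_matmul_add_ops(
    ...,
    target_output=conv_node.output[0],
    transpose_weight=False,
    bias_add_name="{}_bias_add".format(conv_node.name),
)
```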
@@ -806,65 +806,62 @@ def _add_quantized_conv_matmul_add_ops(
     )
     model.graph.node.append(integer_op_node)
 
+    output_scale = input_quantize_params.scale * weight_quantize_params.scale
+    output_scale_name = "{}_output.scale".format(node.name)
+    model.graph.initializer.append(
+        numpy_helper.from_array(numpy.asarray(output_scale), name=output_scale_name)
+    )
+
+    last_output = integer_op_output
+
     # Add bias + zero point correction
     # quantize bias
-    bias_initializer = numpy_helper.to_array(bias_initializer)
-    bias_scale = input_quantize_params.scale * weight_quantize_params.scale
-    bias_zero_point = 0
-    quantized_bias = _quantize_array(
-        bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
-    )
-    if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
-        # reshape for bias add broadcasting
-        quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+    if bias_initializer is not None:
+        bias_initializer = numpy_helper.to_array(bias_initializer)
 
-    quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
-    quantized_bias_initializer = numpy_helper.from_array(
-        quantized_bias, name=quantized_bias_name
-    )
-    model.graph.initializer.append(quantized_bias_initializer)
-    quantized_bias_scale_name = "{}.scale".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_scale), name=quantized_bias_scale_name
+        bias_zero_point = 0
+        quantized_bias = _quantize_array(
+            bias_initializer, output_scale, bias_zero_point, dtype=numpy.int32
         )
-    )
-    quantized_bias_zero_point_name = "{}.zero_point".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_zero_point, dtype=numpy.uint8),
-            name=quantized_bias_zero_point_name,
+        if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
+            # reshape for bias add broadcasting
+            quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+
+        quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
+        quantized_bias_initializer = numpy_helper.from_array(
+            quantized_bias, name=quantized_bias_name
         )
-    )
+        model.graph.initializer.append(quantized_bias_initializer)
 
-    # get INT32 Add inputs and outputs
-    quant_add_inputs = [
-        integer_op_output,  # MatMul/Conv integer outputs (INT32)
-        quantized_bias_name,  # Quantized bias (INT32)
-    ]
+        # get INT32 Add inputs and outputs
+        quant_add_inputs = [
+            last_output,  # MatMul/Conv integer outputs (INT32)
+            quantized_bias_name,  # Quantized bias (INT32)
+        ]
 
-    quant_add_name = "{}_bias_add_quant".format(node.name)
-    quant_add_output = (
-        output_quantize_node.output[0]
-        if output_quantize_node
-        else f"{quant_add_name}_output"
-    )
+        quant_add_name = "{}_bias_add_quant".format(node.name)
+        quant_add_output = (
+            output_quantize_node.output[0]
+            if output_quantize_node
+            else f"{quant_add_name}_output"
+        )
 
-    # create Add node and add it to graph
-    qadd_node = onnx.helper.make_node(
-        "Add",
-        quant_add_inputs,
-        [quant_add_output],
-        quant_add_name,
-    )
-    model.graph.node.append(qadd_node)
+        # create Add node and add it to graph
+        qadd_node = onnx.helper.make_node(
+            "Add",
+            quant_add_inputs,
+            [quant_add_output],
+            quant_add_name,
+        )
+        model.graph.node.append(qadd_node)
+        last_output = quant_add_output
 
     # create Cast node and add it to graph
-    cast_node_name = "{}_cast".format(quant_add_name)
-    cast_node_output = "{}_cast".format(quant_add_output)
+    cast_node_name = "{}_cast".format(node.name)
+    cast_node_output = "{}_output".format(cast_node_name)
     cast_node = onnx.helper.make_node(
         "Cast",
-        [quant_add_output],
+        [last_output],
         [cast_node_output],
         cast_node_name,
         to=getattr(onnx.TensorProto, "FLOAT"),  # get Float32 enum id
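The rewritten block quantizes the bias only when one exists, reusing the unconditionally created `output_scale` (input scale times weight scale) so the INT32 bias lands in the same scale as the ConvInteger/MatMulInteger accumulator, and `last_output` threads whichever tensor is live (the integer op output, or the bias Add output) into the Cast. A minimal numpy sketch of the bias step, assuming `_quantize_array` rounds `value / scale + zero_point` and casts to the requested dtype:

```python
import numpy

def quantize_bias(bias, input_scale, weight_scale):
    # Same scale as the INT32 accumulator: s_out = s_in * s_w, zero point 0
    output_scale = input_scale * weight_scale
    quantized = numpy.round(bias / output_scale).astype(numpy.int32)
    # Reshape to (1, C, 1, 1) so the Add broadcasts over NCHW Conv outputs
    return quantized.reshape(1, quantized.shape[0], 1, 1)

q_bias = quantize_bias(numpy.array([0.5, -0.25]), input_scale=0.02, weight_scale=0.01)
print(q_bias.shape, q_bias.ravel())  # (1, 2, 1, 1) [ 2500 -1250]
```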
@@ -874,9 +871,9 @@ def _add_quantized_conv_matmul_add_ops(
     # create Mul node for rescale
     mul_node_inputs = [
         cast_node_output,  # a
-        quantized_bias_scale_name,  # b -> rescale factor
+        output_scale_name,  # b -> rescale factor
     ]
-    mul_node_name = "{}_rescale_mul".format(quant_add_name)
+    mul_node_name = "{}_rescale_mul".format(cast_node_name)
     mul_node = onnx.helper.make_node(
         "Mul",
         mul_node_inputs,
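With the bias now optional, the rescale `Mul` can no longer key off the bias tensors, so it consumes `output_scale_name` instead of `quantized_bias_scale_name` and is named after the always-present `Cast` node rather than the possibly absent bias Add. Numerically the Cast + Mul pair is just dequantization of the INT32 result; a toy check with made-up values:

```python
import numpy

int32_out = numpy.int32(2610)                    # accumulator (+ quantized bias)
output_scale = 0.02 * 0.01                       # input_scale * weight_scale
print(numpy.float32(int32_out) * output_scale)   # 0.522 -> back in FP32 range
```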
@@ -979,10 +976,10 @@ def _convert_quantizable_gemm_no_activations(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(gemm_node.name),
             target_output=gemm_node.output[0],
             transpose_weight=transpose_weight,
+            bias_add_name="{}_bias_add".format(gemm_node.name),
+            bias_initializer=bias_initializer,
         )
 
         # Cleanup
@@ -1108,14 +1105,14 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name=bias_add_node.name,
             target_output=(
                 output_dequantize_node.output[0]
                 if output_dequantize_node
                 else bias_add_node.output[0]
            ),
             transpose_weight=True,
+            bias_add_name=bias_add_node.name,
+            bias_initializer=bias_initializer,
             output_quantize_node=output_quantize_node,
             output_dequantize_node=output_dequantize_node,
         )
@@ -1164,7 +1161,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     |            |               |
     |     DequantizeLinear      |
     |            |               |
-    |          Conv (with bias)
+    |          Conv (with optional bias)
     |            |
     |          OUTPUT
     | We end up converting to:
@@ -1174,7 +1171,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     |            |
     |     ConvInteger (with constant uint8 kernel)
     |            |
-    |     Add (constant bias + zero point correction)
+    |     Add (optional, constant bias + zero point correction)
     |            |
     |     Cast (INT32 -> FP32)
     |            |
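The two docstring diagrams describe the same rewrite before and after this commit: the bias Add in the converted graph is now emitted only when the Conv carried a bias. A toy numpy walk-through of the converted path (quantized values and scales are made up, and zero points are taken as 0 for brevity):

```python
import numpy

x_scale, w_scale = 0.02, 0.01
x_q = numpy.array([[10, 20]], dtype=numpy.uint8)   # quantized input
w_q = numpy.array([[3], [4]], dtype=numpy.uint8)   # quantized kernel

# ConvInteger/MatMulInteger-style INT32 accumulation
acc = x_q.astype(numpy.int32) @ w_q.astype(numpy.int32)

# optional Add: constant bias quantized to the accumulator scale
bias = numpy.array([0.5])
acc = acc + numpy.round(bias / (x_scale * w_scale)).astype(numpy.int32)

# Cast (INT32 -> FP32) followed by the Mul rescale
out = acc.astype(numpy.float32) * (x_scale * w_scale)
print(out)  # [[0.522]] ~= (10*0.02)*(3*0.01) + (20*0.02)*(4*0.01) + 0.5
```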
@@ -1187,10 +1184,10 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     conv_nodes = [n for n in model.graph.node if n.op_type in ["Conv"]]
     orig_conv_weight_name_to_node_ids = defaultdict(list)
     for conv_node in conv_nodes:
-        if len(conv_node.input) != 3:
-            # this function currently only converts Conv nodes with bias param
-            # (i.e. from folded batch norm value)
-            continue
+        # if len(conv_node.input) != 3:
+        #     # this function currently only converts Conv nodes with bias param
+        #     # (i.e. from folded batch norm value)
+        #     continue
 
         graph = ONNXGraph(model)
 
@@ -1226,12 +1223,15 @@ def _convert_quantizable_conv_integer(model: ModelProto):
         if input_quantize_node.op_type != "DequantizeLinear":
             continue
 
-        bias_initializer = graph.get_init_by_name(conv_node.input[2])
-        if bias_initializer is None:
-            _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}")
-            continue
+        if len(conv_node.input) == 3:
+            bias_initializer = graph.get_init_by_name(conv_node.input[2])
+        else:
+            bias_initializer = None
 
-        _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
+        if bias_initializer is None:
+            _LOGGER.debug(f"Matched quantizable Conv weight: {conv_node.name}")
+        else:
+            _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
 
         # Conversion
         _add_quantized_conv_matmul_add_ops(
@@ -1241,10 +1241,10 @@ def _convert_quantizable_conv_integer(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(conv_node.name),
             target_output=conv_node.output[0],
             transpose_weight=False,
+            bias_add_name="{}_bias_add".format(conv_node.name),
+            bias_initializer=bias_initializer,
         )
         orig_conv_weight_name_to_node_ids[input_quantize_node.input[0]].append(
             "{}_quant".format(conv_node.output[0])
