Commit b56b19d

refine dtype convert code (#1092)
1 parent 7bf8107 commit b56b19d

File tree

3 files changed: +27 −51 lines changed


neural_compressor/adaptor/ox_utils/quantizer.py

Lines changed: 8 additions & 5 deletions
@@ -271,7 +271,7 @@ def dfs(match_nodes, node, pattern):
     if all([i.op_type in ['QuantizeLinear', 'DequantizeLinear'] \
             for i in self.model.get_children(match_nodes[0])]):
         self.remove_nodes.append(match_nodes[0])
-    else:
+    else: # pragma: no cover
         parent = self.model.get_parents(match_nodes[0])[0]
         children = self.model.get_children(match_nodes[1])
         input_dtype = '1' # float32
@@ -326,9 +326,12 @@ def dtype_cast(self, node, cfg, keep_io_types=True): # pragma: no cover
         if initializer is not None:
             if initializer.data_type != onnx_proto.TensorProto.FLOAT:
                 continue
-            cast_tensor(initializer, cfg, min_positive_val, max_finite_val)
-            self.new_value_info[tensor_name] = ValueInfo(tensor_name,
-                TensorProto.FLOAT, dtype_mapping[cfg])
+            new_tensor = cast_tensor(initializer, cfg)
+            if new_tensor:
+                self.model.remove_initializer(initializer)
+                self.model.add_initializer(new_tensor)
+                self.new_value_info[tensor_name] = ValueInfo(tensor_name,
+                    TensorProto.FLOAT, dtype_mapping[cfg])
         else:
             if tensor_name in self.value_infos and \
                 self.value_infos[tensor_name].type.HasField('tensor_type') and \
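
With this change, cast_tensor no longer mutates the initializer in place; it returns a fresh TensorProto (or None), and dtype_cast then swaps the initializer in the model. A minimal sketch of that replace-the-initializer pattern using plain onnx graph APIs instead of the quantizer's model wrapper (the float16 target and the lookup-by-name helper are illustrative assumptions):

    import onnx
    from onnx import TensorProto, helper, numpy_helper

    def replace_fp32_initializer(model: onnx.ModelProto, name: str) -> None:
        # Find the float32 initializer by name; do nothing if absent or not float32.
        old = next((t for t in model.graph.initializer if t.name == name), None)
        if old is None or old.data_type != TensorProto.FLOAT:
            return
        # Re-emit the same values under the same name with a float16 element type.
        arr = numpy_helper.to_array(old)
        new = helper.make_tensor(name=old.name, data_type=TensorProto.FLOAT16,
                                 dims=arr.shape, vals=arr)
        # Swap the old initializer for the converted one, mirroring
        # self.model.remove_initializer / add_initializer in the diff above.
        model.graph.initializer.remove(old)
        model.graph.initializer.append(new)
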
@@ -681,7 +684,7 @@ def quantize_weights_per_channel(self, node, indices, weight_qType, scheme, axis
         qlinear_node = make_quant_node(weight_name + "_QuantizeLinear",
             [weight_name, scale_name, zp_name], [weight_name + "_quantized"])
         dequant_node = make_dquant_node(weight_name + "_DequantizeLinear",
-            [weight_name + "_QuantizeLinear", scale_name, zp_name],
+            [weight_name + "_quantized", scale_name, zp_name],
             [weight_name + "_dequantized"])
         self.replace_input.append([node, weight_name, dequant_node.output[0]])
         self.new_nodes.extend([qlinear_node, dequant_node])
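
The per-channel fix points DequantizeLinear at the tensor produced by QuantizeLinear (weight_name + "_quantized") rather than at the QuantizeLinear node name: graph edges are tensor names, not node names. A small sketch of the intended Q/DQ wiring with onnx.helper.make_node, assuming make_quant_node and make_dquant_node are thin wrappers around it (the weight name here is made up):

    from onnx import helper

    weight_name = "conv1.weight"                 # illustrative weight tensor name
    scale_name = weight_name + "_scale"
    zp_name = weight_name + "_zero_point"

    # QuantizeLinear consumes the fp32 weight and emits "<weight>_quantized".
    qlinear_node = helper.make_node(
        "QuantizeLinear",
        inputs=[weight_name, scale_name, zp_name],
        outputs=[weight_name + "_quantized"],
        name=weight_name + "_QuantizeLinear")

    # DequantizeLinear must read the tensor "<weight>_quantized" produced above;
    # the node name "<weight>_QuantizeLinear" is not a valid graph edge.
    dequant_node = helper.make_node(
        "DequantizeLinear",
        inputs=[weight_name + "_quantized", scale_name, zp_name],
        outputs=[weight_name + "_dequantized"],
        name=weight_name + "_DequantizeLinear")
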

neural_compressor/adaptor/ox_utils/util.py

Lines changed: 12 additions & 46 deletions
@@ -18,7 +18,7 @@
 
 import os
 import numpy as np
-from onnx import helper
+from onnx import helper, numpy_helper
 from onnx import onnx_pb as onnx_proto
 from enum import Enum
 from pathlib import Path
@@ -103,58 +103,24 @@ def split_shared_bias(model):
             node.input[2] = new_input_name
     return model
 
-def convert_np_to_float16(np_array, min_positive_val=1e-7, max_finite_val=1e4): # pragma: no cover
+def cast_tensor(tensor, dtype): # pragma: no cover
     '''
-    Convert float32 numpy array to float16 without changing sign or finiteness.
-    Positive values less than min_positive_val are mapped to min_positive_val.
-    Positive finite values greater than max_finite_val are mapped to max_finite_val.
-    Similar for negative values. NaN, 0, inf, and -inf are unchanged.
-    '''
-    def between(a, b, c):
-        return np.logical_and(a < b, b < c)
-    np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array)
-    np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array)
-    np_array = np.where(between(max_finite_val, np_array, float('inf')), max_finite_val, np_array)
-    np_array = np.where(between(float('-inf'), np_array, -max_finite_val), -max_finite_val, np_array)
-    return np.float16(np_array)
-
-def _npfloat16_to_int(np_list):
-    '''
-    Convert numpy float16 to python int.
-    param np_list: numpy float16 list
-    return int_list: python int list
-    '''
-    return [int(bin(_.view('H'))[2:].zfill(16), 2) for _ in np_list]
-
-def cast_tensor(tensor, dtype, min_positive_val=1e-7, max_finite_val=1e4): # pragma: no cover
-    '''
-    Convert tensor float to float16.
+    Convert tensor float to target dtype.
     param tensor: TensorProto object
-    return tensor_float16: converted TensorProto object
-    Example:
-        from onnxmltools.utils.float16_converter import convert_tensor_float_to_float16
-        new_tensor = convert_tensor_float_to_float16(tensor)
+    return tensor_target_dtype: converted TensorProto object
     '''
     if not isinstance(tensor, onnx_proto.TensorProto):
         raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor))
 
     if tensor.data_type == onnx_proto.TensorProto.FLOAT:
-        tensor.data_type = onnx_proto.TensorProto.FLOAT16
-        # convert float_data (float type) to float16 and write to int32_data
-        if tensor.float_data:
-            float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val)
-            int_list = _npfloat16_to_int(float16_data)
-            tensor.int32_data[:] = int_list
-            tensor.float_data[:] = []
-        # convert raw_data (bytes type)
-        if tensor.raw_data:
-            # convert n.raw_data to float
-            float32_list = np.fromstring(tensor.raw_data, dtype='float32')
-            # convert float to float16
-            float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val)
-            # convert float16 to bytes and write back to raw_data
-            tensor.raw_data = float16_list.tostring()
-        return tensor
+        new_tensor = helper.make_tensor(
+            name=tensor.name,
+            data_type=dtype_mapping[dtype],
+            dims=numpy_helper.to_array(tensor).shape,
+            vals=numpy_helper.to_array(tensor)
+        )
+        return new_tensor
+    return None
 
 def remove_init_from_model_input(model):
     inputs = model.model.graph.input
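
For context, a self-contained sketch of the numpy_helper round-trip the new cast_tensor relies on, with the dtype_mapping lookup replaced by an explicit TensorProto.FLOAT16 and a made-up tensor name:

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    # A small float32 initializer standing in for a model weight.
    weights = np.arange(6, dtype=np.float32).reshape(2, 3)
    fp32_tensor = numpy_helper.from_array(weights, name="fc.weight")

    # Same idea as the refined cast_tensor: read the values back via numpy_helper
    # and re-emit them as a new TensorProto with the target element type.
    arr = numpy_helper.to_array(fp32_tensor)
    fp16_tensor = helper.make_tensor(name=fp32_tensor.name,
                                     data_type=TensorProto.FLOAT16,
                                     dims=arr.shape,
                                     vals=arr)

    print(fp16_tensor.data_type)               # 10 == TensorProto.FLOAT16
    print(numpy_helper.to_array(fp16_tensor))  # values round-tripped as float16
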

test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py

Lines changed: 7 additions & 0 deletions
@@ -830,6 +830,13 @@ def test_adaptor(self):
         from neural_compressor.utils.utility import recover
         model = recover(self.ir3_model, './nc_workspace/recover/history.snapshot', 0)
         self.assertTrue(model.model == q_model.model)
+
+        quantizer = Quantization("qdq.yaml")
+        quantizer.calib_dataloader = self.matmul_dataloader
+        quantizer.eval_dataloader = self.matmul_dataloader
+        quantizer.model = self.matmul_model
+        q_model = quantizer.fit()
+        self.assertNotEqual(q_model, None)
         options.onnxrt.qdq_setting.AddQDQPairToWeight = False
 
         options.onnxrt.qdq_setting.DedicatedQDQPair = True
