
Commit 4514402: Support diagnosis for ONNX NLP models (#1012)

Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>

1 parent d64d0b7 · commit 4514402

7 files changed: +66 −44 lines


docs/source/releases_info.md (3 additions, 1 deletion)

```diff
@@ -17,6 +17,8 @@ Contact [inc.maintainers@intel.com](mailto:inc.maintainers@intel.com) if you nee
 
 The MSE tuning strategy does not work with the PyTorch adaptor layer. This strategy requires a comparison between the FP32 and INT8 tensors to decide which op impacts the final quantization accuracy. The PyTorch adaptor layer does not implement this inspect tensor interface. Therefore, do not choose the MSE tuning strategy for PyTorch models.
 
+The diagnosis function does not work with ONNX Runtime 1.13.1 for QDQ-format quantization of ONNX models: it cannot dump the output values of QDQ pairs due to a framework limitation.
+
 ## Incompatible Changes
 
 [Neural Compressor v1.2](https://github.com/intel/neural-compressor/tree/v1.2) introduces incompatible changes in user facing APIs. Please refer to [incompatible changes](incompatible_changes.md) to know which incompatible changes are made in v1.2.
@@ -25,4 +27,4 @@ The MSE tuning strategy does not work with the PyTorch adaptor layer. This strat
 
 [Neural Compressor v1.7](https://github.com/intel/neural-compressor/tree/v1.7) renames the pip/conda package name from lpot to neural_compressor. To run old examples on latest software, please replace package name for compatibility with `sed -i "s|lpot|neural_compressor|g" your_script.py` .
 
-[Neural Compressor v2.0](https://github.com/intel/neural-compressor/tree/v2.0) renames the `DATASETS` class as `Datasets`, please notice use cases like `from neural_compressor.data import Datasets`. Details please check the [PR](https://github.com/intel/neural-compressor/pull/244/files).
+[Neural Compressor v2.0](https://github.com/intel/neural-compressor/tree/v2.0) renames the `DATASETS` class as `Datasets`, please notice use cases like `from neural_compressor.data import Datasets`. Details please check the [PR](https://github.com/intel/neural-compressor/pull/244/files).
```
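As context for the known-issue note added above, a caller can guard diagnosis on the installed ONNX Runtime version. This is a minimal, hypothetical sketch (not part of this commit); it mirrors the `Version` comparison used in the updated test further below and assumes only `onnxruntime` and `packaging` are installed:

```python
# Hedged sketch: skip tensor diagnosis on ONNX Runtime 1.13.1, where QDQ
# pair outputs cannot be dumped (see the known-issue note above).
import onnxruntime as ort
from packaging.version import Version

def diagnosis_supported(quant_format: str) -> bool:
    """Return False for the known-bad combination; True otherwise."""
    if quant_format == 'qdq' and Version(ort.__version__) == Version("1.13.1"):
        return False
    return True

if diagnosis_supported('qdq'):
    print("safe to inspect QDQ tensors on ORT", ort.__version__)
else:
    print("diagnosis disabled: ORT 1.13.1 cannot dump QDQ pair outputs")
```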

neural_compressor/adaptor/onnxrt.py (2 additions, 1 deletion)

```diff
@@ -493,7 +493,8 @@ def inspect_tensor(self, model, dataloader, op_list=[],
                                  white_nodes=op_list,
                                  backend=self.backend)
         tensors = augment.dump_tensor(activation=(inspect_type!='weight'),
-                                      weight=(inspect_type!='activation'),)
+                                      weight=(inspect_type!='activation'),
+                                      format=self.format)
         if save_to_disk:
             if not save_path:
                 save_path = self.work_space
```
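A hypothetical call site for the extended signature, assuming an `ONNXRTAugment` instance named `augment` as in the hunk above; the `format` value comes from the adaptor's configured quantization format:

```python
# Sketch only: forward the quantization format so dump_tensor can map
# QDQ tensor names ('*_dequantized') as well as QOperator names ('*_quantized').
tensors = augment.dump_tensor(activation=True,  # dump activation tensors
                              weight=False,     # skip weight tensors
                              format='qdq')     # keyword added by this commit
```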

neural_compressor/adaptor/ox_utils/calibration.py (26 additions, 19 deletions)

```diff
@@ -53,7 +53,7 @@ def __init__(self, model_wrapper,
                  black_nodes=[],
                  white_nodes=[],
                  iterations=[],
-                 backend=['CPUExecutionProvider'],
+                 backend='CPUExecutionProvider',
                  reduce_range=False):
         """Initialization.
@@ -149,7 +149,7 @@ def augment_graph(self, activation_only=False, weight_only=False):
             elif not self.already_quantized and input in initializers:
                 tensors_to_dump.add(input)
             elif activation_only:
-                tensors_to_dump.update(node.output)
+                tensors_to_dump.update([node.input[0]])
 
         model_inputs = [i.name for i in model.graph.input]
         for tensor in tensors_to_dump:
@@ -160,9 +160,7 @@ def augment_graph(self, activation_only=False, weight_only=False):
             for augment_node_type in self.augment_nodes:
                 if augment_node_type in ['DequantizeLinear']:
                     # insert DequantizeLinear node as output
-                    if tensor.endswith('_scale') or tensor.endswith('_zero_point') or \
-                        tensor.endswith('_QuantizeLinear') or \
-                        tensor.endswith('_QuantizeInput_quantized'):
+                    if tensor.endswith('_scale') or tensor.endswith('_zero_point'):
                         continue
 
                     if not self.dynamically_quantized:
@@ -483,14 +481,16 @@ def calculate_quantization_params(self, q_config, quantization_thresholds):
 
         return quantization_params
 
-    def dump_tensor(self, activation=True, weight=False):
+    def dump_tensor(self, activation=True, weight=False, format=None):
         """Dump activation or weight or both from the model."""
+        is_qdq = False
         if "QuantizeLinear" in [node.op_type for node in self.model.graph.node] or \
             "DynamicQuantizeLinear" in [node.op_type for node in self.model.graph.node]:
             self.augment_nodes = ["DequantizeLinear"]
             self.already_quantized = True
             self.dynamically_quantized = \
                 "DynamicQuantizeLinear" in [node.op_type for node in self.model.graph.node]
+            is_qdq = format == 'qdq'
         self.augment_graph(activation_only=not weight, weight_only=not activation)
         _, output_dicts = self.get_intermediate_outputs()
         iters = len(list(output_dicts.values())[-1])
@@ -507,30 +507,37 @@ def dump_tensor(self, activation=True, weight=False):
             if tensor_name.replace('_dequantized', '_quantized') in model_initializer_names:
                 nodes = [node for node in map_input[tensor_name] \
                     if node.name.replace('_quant', '') in self.white_nodes]
-            elif tensor_name.replace('_quantized', '') in model_input_names:
-                continue
-            else:
+            elif tensor_name in model_output_names:
                 nodes = [map_output[tensor_name]]
+            else:
+                nodes = map_input[tensor_name]
             for node in nodes:
                 node_name = node.name.replace('_quant', '')
                 if tensor_name in model_output_names and node_name not in self.white_nodes:
                     continue
-                while node_name not in self.white_nodes and self.already_quantized:
-                    node = augmengted_wrapper.get_parents(node, output_name_to_node=map_output)[0]
-                    node_name = node.name.replace('_quant', '')
                 if node_name not in self.white_nodes:
                     continue
                 if node_name not in map_node_weight:
                     map_node_weight[node_name] = {}
-                if tensor_name not in model_initializer_names:
+                if ((is_qdq and tensor_name.replace('_dequantized', '_quantized') not in model_initializer_names) or \
+                    (not is_qdq and tensor_name not in model_initializer_names)) and \
+                    tensor_name in node.input[:2]:
                     for i in range(iters):
-                        map_node_activation[i][node_name] = \
-                            {tensor_name.replace('_quantized', ''): tensors[i]}
-                elif not (node.op_type in ['Conv', 'Gemm', 'FusedConv'] and tensor_name not in node.input[:2]) and \
+                        if node.op_type in ['Attention', 'QAttention'] and tensor_name not in node.input[:2]:
+                            continue
+                        if is_qdq:
+                            map_node_activation[i][node_name] = \
+                                {tensor_name.replace('_dequantized', '').replace('_' + node_name, ''): tensors[i]}
+                        else:
+                            map_node_activation[i][node_name] = \
+                                {tensor_name.replace('_quantized', ''): tensors[i]}
+                elif not (node.op_type in ['QGemm'] and tensor_name not in node.input[:6]) and \
                     not (node.op_type in ['QLinearConv'] and tensor_name not in node.input[:8]) and \
-                    not (node.op_type in ['QGemm'] and tensor_name not in node.input[:6]):
-                    map_node_weight[node_name].update({tensor_name.replace('_quantized', ''): \
-                                                       tensors[0]})
+                    not (node.op_type in ['Conv', 'Gemm', 'FusedConv'] and tensor_name not in node.input[:2]):
+                    if is_qdq:
+                        map_node_weight[node_name].update({tensor_name.replace('_dequantized', ''): tensors[0]})
+                    else:
+                        map_node_weight[node_name].update({tensor_name.replace('_quantized', ''): tensors[0]})
         dumped_tensors_map = {}
         if weight:
             dumped_tensors_map.update({"weight": map_node_weight})
```
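The `is_qdq` branches above differ only in how dumped tensor names are mapped back to their FP32 names. A standalone sketch of that normalization (hypothetical helper, not the library's API):

```python
# Hypothetical helper reproducing the two name mappings in dump_tensor above:
# QDQ graphs dump '<name>_dequantized' (sometimes suffixed with the node name),
# while QOperator graphs dump '<name>_quantized'.
def normalize_activation_name(tensor_name: str, node_name: str, is_qdq: bool) -> str:
    if is_qdq:
        return tensor_name.replace('_dequantized', '').replace('_' + node_name, '')
    return tensor_name.replace('_quantized', '')

assert normalize_activation_name('input_dequantized_matmul', 'matmul', True) == 'input'
assert normalize_activation_name('input_quantized', 'matmul', False) == 'input'
```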

neural_compressor/adaptor/ox_utils/operators/maxpool.py (2 additions, 2 deletions)

```diff
@@ -66,7 +66,7 @@ def convert(self, convert_format):
             all([i.op_type != 'QuantizeLinear' for i in children]): # pragma: no cover
             return
         node.input[0] = parent.input[0]
-        node.output[0] = node.output[0] + '_quantized'
+        node.output[0] = node.output[0].replace('_QuantizeInput', '_quantized')
         for child in children:
             if child.op_type == 'QuantizeLinear':
                 self.quantizer.remove_nodes.append(child)
@@ -82,4 +82,4 @@ class QMaxPoolOperator(QOperator):
 
     def __init__(self, onnx_node, children, initializers):
         """Initialization."""
-        super().__init__(onnx_node, children, initializers)
+        super().__init__(onnx_node, children, initializers)
```
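The rename matters when the MaxPool output already carries a `_QuantizeInput` suffix: blindly appending `_quantized` produced names that downstream lookups could not resolve. A quick illustration with a made-up tensor name:

```python
# Illustration with a hypothetical tensor name:
old_output = 'maxpool_out_QuantizeInput'
# before this commit: old_output + '_quantized' -> 'maxpool_out_QuantizeInput_quantized'
# after this commit:
assert old_output.replace('_QuantizeInput', '_quantized') == 'maxpool_out_quantized'
```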

neural_compressor/model/onnx_model.py (14 additions, 4 deletions)

```diff
@@ -329,11 +329,21 @@ def get_scale_zero(self, tensor):
         if not tensor.endswith('_quantized'):
             logger.debug("Find {} in the quantized graph is not quantized.".format(tensor))
             return None, None
-        input_name_to_nodes = self._input_name_to_nodes
-        node = input_name_to_nodes[tensor][0]
-        scale = "_".join(tensor.split('_')[:-1] + ['scale'])
+        node = self._input_name_to_nodes[tensor][0]
+        parent = self._output_name_to_node[tensor] if tensor in self._output_name_to_node else None
+        direct_int8 = ['Reshape', 'Transpose', 'Squeeze', 'Unsqueeze', 'MaxPool', 'Pad']
+        if parent is not None and parent.op_type in direct_int8:
+            fp32_tensor_name = \
+                parent.input[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+        elif node.op_type in ['Gather']:
+            fp32_tensor_name = \
+                node.output[0].replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+        else:
+            fp32_tensor_name = \
+                tensor.replace('_quantized', '').replace('_QuantizeLinear', '').replace('_QuantizeInput', '')
+        scale = fp32_tensor_name + '_scale'
         scale_tensor = self.get_initializer(scale)
-        zo = "_".join(tensor.split('_')[:-1] + ['zero_point'])
+        zo = fp32_tensor_name + '_zero_point'
         zo_tensor = self.get_initializer(zo)
 
         #TODO check if scale_tensor and zero_point is needed
```
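The net effect of the `get_scale_zero` change is that scale and zero-point initializer names are derived from a recovered FP32 tensor name rather than by chopping off the last `_`-separated token. A small sketch with an assumed tensor name:

```python
# Sketch with an assumed name: derive initializer names the way the new code does.
tensor = 'input_QuantizeLinear_quantized'   # hypothetical quantized tensor name
fp32_name = (tensor.replace('_quantized', '')
                   .replace('_QuantizeLinear', '')
                   .replace('_QuantizeInput', ''))
assert fp32_name == 'input'
scale_name = fp32_name + '_scale'            # 'input_scale'
zero_point_name = fp32_name + '_zero_point'  # 'input_zero_point'
# The old token-chopping logic would have produced 'input_QuantizeLinear_scale',
# which has no matching initializer in QDQ graphs.
```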

test/adaptor/onnxrt_adaptor/test_adaptor_onnxrt.py (5 additions, 1 deletion)

```diff
@@ -6,6 +6,7 @@
 import torchvision
 import onnx
 import numpy as np
+from packaging.version import Version
 from collections import OrderedDict
 from onnx import onnx_pb as onnx_proto
 from onnx import helper, TensorProto, numpy_helper
@@ -731,6 +732,8 @@ def evaluate(self):
         with self.assertRaises(ValueError):
             test()
 
+    @unittest.skipIf(Version(ort.__version__) == Version("1.13.1"),
+        "This function does not work with ONNX Runtime 1.13.1 for QDQ format quantization of ONNX models.")
     def test_inspect_tensor(self):
         framework_specific_info = {"device": "cpu",
                                    "approach": "post_training_static_quant",
@@ -774,7 +777,8 @@ def test_inspect_tensor(self):
         self.assertTrue(len(fp32_tensor['activation']) == len(int8_tensor['activation']))
         self.assertTrue(sorted(fp32_tensor['activation'][0].keys()) == sorted(int8_tensor['activation'][0].keys()))
         for op in op_list:
-            self.assertTrue(sorted(fp32_tensor['activation'][0][op].keys()) == sorted(int8_tensor['activation'][0][op].keys()))
+            for x, y in zip(fp32_tensor['activation'][0][op].values(), int8_tensor['activation'][0][op].values()):
+                self.assertTrue(x.shape == y.shape)
 
         if fake_yaml == "qlinear.yaml":
             fp32_tensor = quantizer.strategy.adaptor.inspect_tensor(opt_model.model, self.cv_dataloader, op_list, inspect_type='weight')
```
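The loosened assertion reflects that FP32 and INT8 dumps may use different tensor-name keys for the same op after the QDQ name mapping, while the per-op tensor shapes should still line up. A runnable sketch of the structure the test relies on (inferred from `dump_tensor` above, not an authoritative API contract):

```python
import numpy as np

# Inferred structure: 'activation' is a list with one dict per calibration
# iteration, each mapping op_name -> {tensor_name: ndarray}.
fp32_tensor = {'activation': [{'conv': {'input': np.zeros((1, 3, 5, 5))}}]}
int8_tensor = {'activation': [{'conv': {'input_q': np.zeros((1, 3, 5, 5))}}]}

# Keys may differ between the two dumps (hence the test change), but shapes match:
for op in ['conv']:
    for x, y in zip(fp32_tensor['activation'][0][op].values(),
                    int8_tensor['activation'][0][op].values()):
        assert x.shape == y.shape
```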

test/adaptor/onnxrt_adaptor/test_onnxrt_augment.py (14 additions, 16 deletions)

```diff
@@ -124,9 +124,9 @@ def test_dump_tensor(self):
                                 white_nodes=["conv"])
         map_dumped_tensors = augment.dump_tensor()
         assert "conv" in map_dumped_tensors["activation"][0]
-        assert "C" in map_dumped_tensors["activation"][0]["conv"]
+        assert "A" in map_dumped_tensors["activation"][0]["conv"]
         assert "conv" in map_dumped_tensors["activation"][1]
-        assert "C" in map_dumped_tensors["activation"][1]["conv"]
+        assert "A" in map_dumped_tensors["activation"][1]["conv"]
 
         model, dataloader = self.cv_session
         augment = ONNXRTAugment(ONNXModel(model),
@@ -321,6 +321,7 @@ def test_augment_graph(self):
         #                  |
         #            QuantizeLinear
 
+        Attention_input = helper.make_tensor_value_info('input_quantized', TensorProto.INT8, [7, 13])
         Attention_weight = helper.make_tensor_value_info('weight_quantized', TensorProto.INT8, [13,7])
         weight_quantized = generate_input_initializer([13, 7], np.int8, 'weight_quantized')
         Attention_bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [13, 7])
@@ -340,7 +341,8 @@ def test_augment_graph(self):
         Q_zo = helper.make_tensor_value_info('attn_output_zero_point', TensorProto.INT8, [1])
         attn_output_zero_point = generate_input_initializer([1], np.int8, 'attn_output_zero_point')
         Output = helper.make_tensor_value_info('output', TensorProto.INT8, [13,7])
-        attention_node = onnx.helper.make_node('QAttention', ['weight_quantized',
+        attention_node = onnx.helper.make_node('QAttention', ['input_quantized',
+                                                              'weight_quantized',
                                                               'bias',
                                                               'input_scale',
                                                               'weight_scale',
@@ -354,7 +356,8 @@ def test_augment_graph(self):
                                             name='attn_output_QuantizeLinear')
         graph = helper.make_graph([attention_node, qlinear_node],
                                   'test_graph_5',
-                                  [Attention_weight,
+                                  [Attention_input,
+                                   Attention_weight,
                                    Attention_bias,
                                    Input_scale,
                                    Weight_scale,
@@ -380,14 +383,15 @@ def test_augment_graph(self):
         augment = ONNXRTAugment(ONNXModel(model), data_reader, [], white_nodes=['attention'])
         augment.augment_nodes = ['DequantizeLinear']
         augment.already_quantized = True
+
         augment.augment_graph(activation_only=True, weight_only=False)
         augmented_model = augment.augmented_model
 
         augmented_model_node_names = [node.name for node in augmented_model.graph.node]
         augmented_model_outputs = [output.name for output in augmented_model.graph.output]
         added_node_names = ['attention_quant', 'attn_output_QuantizeLinear']
-        added_outputs = ['attn_output', 'output']
-        self.assertEqual(len(augmented_model_node_names), 2)
+        added_outputs = ['input_quantized_output', 'output']
+        self.assertEqual(len(augmented_model_node_names), 3)
         self.assertEqual(len(augmented_model_outputs), 2)
         for name in added_node_names:
             self.assertTrue(name in augmented_model_node_names)
@@ -406,10 +410,6 @@ def test_augment_graph(self):
         a_scale = generate_input_initializer([1], np.float32, 'A_scale')
         A_zo = helper.make_tensor_value_info('A_zero_point', TensorProto.INT8, [1])
         a_zero_point = generate_input_initializer([1], np.int8, 'A_zero_point')
-        B_scale = helper.make_tensor_value_info('B_scale', TensorProto.FLOAT, [1])
-        b_scale = generate_input_initializer([1], np.float32, 'B_scale')
-        B_zo = helper.make_tensor_value_info('B_zero_point', TensorProto.INT8, [1])
-        b_zero_point = generate_input_initializer([1], np.int8, 'B_zero_point')
         C = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 5, 5])
         c = generate_input_initializer([1, 1, 5, 5], np.int8, 'C')
         C_scale = helper.make_tensor_value_info('C_scale', TensorProto.FLOAT, [1])
@@ -423,14 +423,12 @@ def test_augment_graph(self):
         D_zo = helper.make_tensor_value_info('D_zero_point', TensorProto.INT8, [1])
         d_zero_point = generate_input_initializer([1], np.int8, 'D_zero_point')
         D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 5])
-        quantize_node = onnx.helper.make_node('QuantizeLinear', ['A', 'A_scale', 'A_zero_point'], ['B'], name='A_QuantizeLinear')
-        conv_node = onnx.helper.make_node('QLinearConv', ['B', 'B_scale', 'B_zero_point', 'C', 'C_scale', 'C_zero_point', 'D_scale', 'D_zero_point', 'E'], ['D_quantized'], name='conv_quant', kernel_shape=[3, 3], pads=[1, 1, 1, 1])
+        quantize_node = onnx.helper.make_node('QuantizeLinear', ['A', 'A_scale', 'A_zero_point'], ['A_quantized'], name='A_QuantizeLinear')
+        conv_node = onnx.helper.make_node('QLinearConv', ['A_quantized', 'A_scale', 'A_zero_point', 'C_quantized', 'C_scale', 'C_zero_point', 'D_scale', 'D_zero_point', 'E'], ['D_quantized'], name='conv_quant', kernel_shape=[3, 3], pads=[1, 1, 1, 1])
         dequantize_node = onnx.helper.make_node('DequantizeLinear', ['D_quantized', 'D_scale', 'D_zero_point'], ['D'], name='D_DequantizeLinear')
         graph = helper.make_graph([quantize_node, conv_node, dequantize_node], 'test_graph_5', [A, A_scale, A_zo, C, C_scale, C_zo, E, D_scale, D_zo], [D])
         graph.initializer.add().CopyFrom(a_scale)
         graph.initializer.add().CopyFrom(a_zero_point)
-        graph.initializer.add().CopyFrom(b_scale)
-        graph.initializer.add().CopyFrom(b_zero_point)
         graph.initializer.add().CopyFrom(c)
         graph.initializer.add().CopyFrom(c_scale)
         graph.initializer.add().CopyFrom(c_zero_point)
@@ -449,8 +447,8 @@ def test_augment_graph(self):
 
         augmented_model_node_names = [node.name for node in augmented_model.graph.node]
         augmented_model_outputs = [output.name for output in augmented_model.graph.output]
-        added_node_names = ['A_QuantizeLinear', 'conv_quant', 'D_DequantizeLinear', 'D_quantized_DequantizeLinear']
-        added_outputs = ['D', 'D_quantized_output']
+        added_node_names = ['A_QuantizeLinear', 'conv_quant', 'D_DequantizeLinear', 'A_quantized_DequantizeLinear']
+        added_outputs = ['D', 'A_quantized_output']
         self.assertEqual(len(augmented_model_node_names), 4)
         self.assertEqual(len(augmented_model_outputs), 2)
         for name in added_node_names:
```
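The updated expectations are consistent with the calibration change above: `augment_graph` now dumps a node's first input rather than its outputs, so for the QLinearConv node the dumped activation is `A_quantized` (hence the `A_quantized_DequantizeLinear` node and `A_quantized_output` output) instead of `D_quantized`. A hypothetical illustration of the inferred naming pattern:

```python
# For a dumped quantized tensor '<t>', augmentation appends a DequantizeLinear
# node named '<t>_DequantizeLinear' whose graph output is '<t>_output'
# (pattern inferred from the expectations in the hunk above).
tensor = 'A_quantized'
assert tensor + '_DequantizeLinear' == 'A_quantized_DequantizeLinear'
assert tensor + '_output' == 'A_quantized_output'
```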
