
Commit 8c565a4

fix corner case in ONNX backend (#1115)
1 parent d2a9905 commit 8c565a4

File tree

6 files changed: +33 −49 lines

neural_compressor/adaptor/onnxrt.py

Lines changed: 12 additions & 39 deletions
@@ -269,59 +269,29 @@ def _dump_model_op_stats(self, model):
         for precision in self.query_handler.get_precisions():
             if precision != 'fp32':
                 fp32_op_list += self.query_handler.get_op_types_by_precision(precision=precision)
+        qdq_ops = ["QuantizeLinear", "DequantizeLinear", "DynamicQuantizeLinear"]
         res = {}
         for op_type in fp32_op_list:
             res[op_type] = {'INT8':0, 'BF16': 0, 'FP16': 0, 'FP32':0}
-        for op_type in ["QuantizeLinear", "DequantizeLinear", "DynamicQuantizeLinear"]:
+        for op_type in qdq_ops:
             res[op_type] = {'INT8':0, 'BF16': 0, 'FP16': 0, 'FP32':0}
 
-
-        if self.backend in ["qlinearops", "qdq", "qoperator"] :
-            int8_op_list = ["QLinearConv", "QLinearMatMul", "QAttention",
-                            "QLinearMul", "QLinearRelu", "QLinearClip",
-                            "QLinearLeakyRelu", "QLinearSigmoid", "MaxPool","Squeeze",
-                            "EmbedLayerNormalization", "QLinearGlobalAveragePool",
-                            "QLinearAdd", "Pad", "Split", "Gather", "Reshape", "Concat",
-                            "QuantizeLinear", "DequantizeLinear", "QLinearAveragePool",
-                            "Unsqueeze", "Transpose"
-                            ]
-        else:
-            int8_op_list = ["ConvInteger", "MatMulInteger", "QAttention",
-                            "DynamicQuantizeLSTM", "Gather", "EmbedLayerNormalization",
-                            "DynamicQuantizeLinear"
-                            ]
-
         for node in model.model.graph.node:
-            possible_int8_res = [name for name in int8_op_list if node.op_type.find(name) != -1]
-
-            if any(possible_int8_res):
+            if node.name.endswith('_quant'):
                 if self.backend in ["qlinearops", "qdq", "qoperator"]:
-                    if node.op_type == "QuantizeLinear" or node.op_type == "DequantizeLinear" \
-                        or node.op_type == "DynamicQuantizeLinear":
-                        origin_op_type = node.op_type
-                    else:
-                        origin_op_type = possible_int8_res[0].split('QLinear')[-1]
+                    origin_op_type = node.op_type.split('QLinear')[-1]
                 else:
-                    origin_op_type = possible_int8_res[0].split('Integer')[0]
-
-                if node.op_type in ["Pad", "Split", "Gather", "Concat", "Reshape", "Unsqueeze",
-                                    "Squeeze", "Transpose"]:
-                    if any([output.endswith('_quantized') for output in node.output]) or \
-                        any(['_DequantizeLinear' in inp for inp in node.input]):
-                        origin_op_type = node.op_type
-                    else:
-                        if node.op_type in res:
-                            res[node.op_type]['FP32'] += 1
-                        continue
+                    origin_op_type = node.op_type.split('Integer')[0]
 
                 if origin_op_type == "QAttention":
                     origin_op_type = "Attention"
-                if origin_op_type == "DynamicQuantizeLSTM":
+                elif origin_op_type == "DynamicQuantizeLSTM":
                     origin_op_type = "LSTM"
+                elif origin_op_type == "QEmbedLayerNormalization":
+                    origin_op_type = "EmbedLayerNormalization"
                 res[origin_op_type]['INT8'] += 1
 
-            elif node.op_type in fp32_op_list and \
-                any(['_DequantizeLinear' in inp for inp in node.input]):
+            elif node.op_type in qdq_ops:
                 res[node.op_type]['INT8'] += 1
 
             elif node.op_type in fp32_op_list and node.name in self.quantize_config:

@@ -330,6 +300,9 @@ def _dump_model_op_stats(self, model):
                 else:
                     res[node.op_type][self.quantize_config[node.name].upper()] += 1
 
+            elif node.op_type in res:
+                res[node.op_type]['FP32'] += 1
+
         output_data = [[op_type, sum(res[op_type].values()), res[op_type]['INT8'],
                         res[op_type]['BF16'], res[op_type]['FP16'], res[op_type]['FP32']] for \
                         op_type in res.keys()]
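Note: the rewritten _dump_model_op_stats no longer guesses from hard-coded INT8 op lists; it keys off the '_quant' suffix that the quantizer appends to node names, then strips the backend prefix/suffix from the op type. A minimal sketch of that recovery step, assuming only the naming conventions visible in the diff (the helper name and the asserts are illustrative, not part of the adaptor):

# Sketch: map a quantized node's op_type back to the original op type.
def origin_op_type_of(op_type, backend):
    if backend in ("qlinearops", "qdq", "qoperator"):
        origin = op_type.split("QLinear")[-1]   # 'QLinearConv' -> 'Conv'
    else:
        origin = op_type.split("Integer")[0]    # 'MatMulInteger' -> 'MatMul'
    # special cases handled explicitly in the patch
    return {"QAttention": "Attention",
            "DynamicQuantizeLSTM": "LSTM",
            "QEmbedLayerNormalization": "EmbedLayerNormalization"}.get(origin, origin)

assert origin_op_type_of("QLinearConv", "qdq") == "Conv"
assert origin_op_type_of("MatMulInteger", "integerops") == "MatMul"
assert origin_op_type_of("QEmbedLayerNormalization", "qdq") == "EmbedLayerNormalization"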

neural_compressor/adaptor/ox_utils/onnxrt_mid.py

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ def augment_graph(self, activation_only=False, weight_only=False):
             elif not onnx_version < ONNX18_VERSION:
                 tensors_to_dump.update(node.input)
                 tensors_to_dump.update(node.output)
+                if node.op_type == 'EmbedLayerNormalization' and len(node.output) > 1 and \
+                    node.output[2] in tensors_to_dump:
+                    tensors_to_dump.remove(node.output[2])
             elif weight_only:
                 for input in node.input:
                     if self.already_quantized and \
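Note: the added guard drops EmbedLayerNormalization's third output from the set of tensors dumped for calibration; that output is optional for the op, so collecting it can break the augmented graph. A minimal reproduction of the filtering with a stand-in node (the real code walks onnx.NodeProto objects; the output names here are illustrative):

# Stand-in for onnx.NodeProto: only op_type and output matter here.
class FakeNode:
    op_type = 'EmbedLayerNormalization'
    output = ['output', 'mask_index', 'embedding_sum']

node = FakeNode()
tensors_to_dump = set(node.output)
if node.op_type == 'EmbedLayerNormalization' and len(node.output) > 1 and \
        node.output[2] in tensors_to_dump:
    tensors_to_dump.remove(node.output[2])  # drop the optional third output
assert 'embedding_sum' not in tensors_to_dump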

neural_compressor/adaptor/ox_utils/operators/embed_layernorm.py

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ def convert(self):
         [7] mask (int32) (optional)
         '''
 
-        parents = self.quantizer.model.get_parents(node)
+        parents = [i for i in self.quantizer.model.get_parents(node) \
+                   if i.op_type == 'DequantizeLinear']
         inputs = []
         # 'input_ids'
         inputs.extend([node.input[0]])
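Note: get_parents returns every producer of the node's inputs, but convert() only wants the DequantizeLinear parents that feed the quantized embedding weights; any other parent (e.g. the path producing input_ids) must not be counted. A sketch of the filter with stand-in nodes:

from collections import namedtuple

# Stand-in for onnx.NodeProto; only op_type matters for the filter.
Node = namedtuple('Node', 'op_type')

all_parents = [Node('Gather'), Node('DequantizeLinear'), Node('DequantizeLinear')]
# Keep only the DequantizeLinear producers, as the patched convert() does.
parents = [p for p in all_parents if p.op_type == 'DequantizeLinear']
assert len(parents) == 2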

neural_compressor/adaptor/ox_utils/quantizer.py

Lines changed: 12 additions & 6 deletions
@@ -169,6 +169,11 @@ def merge_dedicated_qdq_pair(self):
                         self.replace_input.append([self.model.get_children(child)[0],
                                                    child.output[0], node.input[0]])
                         self.remove_nodes.append(node)
+            self.model.remove_nodes(self.remove_nodes)
+            self.model.graph().node.extend(self.new_nodes)
+            for node, old_input_name, new_input_name in self.replace_input:
+                self.model.replace_node_input(node, old_input_name, new_input_name)
+            self.model.update()
         elif self.mode != 'qdq' or not self.dedicated_qdq_pair:
             target_type = ['QuantizeLinear', 'DequantizeLinear']
             for op_type in target_type:

@@ -190,11 +195,11 @@ def merge_dedicated_qdq_pair(self):
                             self.replace_input.append([self.model.get_children(dq_nodes[i])[0],
                                                        dq_nodes[i].output[0],
                                                        dq_nodes[idx].output[0]])
-        self.model.remove_nodes(self.remove_nodes)
-        self.model.graph().node.extend(self.new_nodes)
-        for node, old_input_name, new_input_name in self.replace_input:
-            self.model.replace_node_input(node, old_input_name, new_input_name)
-        self.model.update()
+            self.model.remove_nodes(self.remove_nodes)
+            self.model.graph().node.extend(self.new_nodes)
+            for node, old_input_name, new_input_name in self.replace_input:
+                self.model.replace_node_input(node, old_input_name, new_input_name)
+            self.model.update()
 
     def should_cast(self, node):
         if node.name in self.config and self.config[node.name] != 'fp32': # pragma: no cover

@@ -269,7 +274,8 @@ def dfs(match_nodes, node, pattern):
 
             self.remove_nodes.append(match_nodes[1])
             if all([i.op_type in ['QuantizeLinear', 'DequantizeLinear'] \
-                for i in self.model.get_children(match_nodes[0])]):
+                for i in self.model.get_children(match_nodes[0])]) and \
+                match_nodes[0].output[0] not in self.model.output():
                 self.remove_nodes.append(match_nodes[0])
             else: # pragma: no cover
                 parent = self.model.get_parents(match_nodes[0])[0]
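Note: two fixes land in this file. First, the remove/extend/replace/update flush that used to run once after the whole if/elif now runs inside each branch, so edits queued by the dedicated-QDQ-pair path are applied before the method returns. Second, dfs may only delete match_nodes[0] when its output is not also a graph output; otherwise merging the pair would silently drop a model output. A sketch of the flush both branches now perform (the method names mirror the diff; wrapping them in a helper is my framing, not how the source factors it):

def flush_pending_edits(model, remove_nodes, new_nodes, replace_input):
    model.remove_nodes(remove_nodes)        # drop the merged QDQ nodes
    model.graph().node.extend(new_nodes)    # append any replacement nodes
    for node, old_input_name, new_input_name in replace_input:
        model.replace_node_input(node, old_input_name, new_input_name)
    model.update()                          # rebuild cached graph metadata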

neural_compressor/adaptor/ox_utils/registry.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@
 from .operators.qdq_base_operator import QDQOperatorBase
 from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul
 from .operators.attention import AttentionQuant, QDQAttention
-from .operators.embed_layernorm import EmbedLayerNormalizationQuant
+from .operators.embed_layernorm import EmbedLayerNormalizationQuant, QDQEmbedLayerNormalization
 from .operators.gather import GatherConverter, GatherQuant
 from .operators.conv import QLinearConv, ConvInteger, QDQConv
 from .operators.activation import QLinearActivation, QDQRemovableActivation, QDQActivation

@@ -92,7 +92,8 @@
     "AveragePool": QDQPool,
     "Unsqueeze" : QDQDirect8BitOp,
     "Concat": QDQConcat,
-    "Split": QDQSplit
+    "Split": QDQSplit,
+    "EmbedLayerNormalization": QDQEmbedLayerNormalization
 }
 
 CastRegistry = {
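Note: with QDQEmbedLayerNormalization registered, the QDQ quantizer can dispatch EmbedLayerNormalization nodes by op_type like any other supported op. An illustrative dispatch; the handler classes are stand-ins for the real imports, and the fallback is my assumption, not the library's documented behavior:

class QDQOperatorBase: pass
class QDQSplit(QDQOperatorBase): pass
class QDQEmbedLayerNormalization(QDQOperatorBase): pass

QDQRegistry = {
    "Split": QDQSplit,
    "EmbedLayerNormalization": QDQEmbedLayerNormalization,  # new in this commit
}

def handler_for(op_type):
    # Fall back to the base handler when an op has no dedicated class.
    return QDQRegistry.get(op_type, QDQOperatorBase)

assert handler_for("EmbedLayerNormalization") is QDQEmbedLayerNormalization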

neural_compressor/adaptor/ox_utils/util.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def split_shared_bias(model):
     for node in node_list[1:]:
         if node.op_type not in ['Conv', 'FusedConv']:
             continue
-        if node.input[2] == input_name:
+        if len(node.input) > 2 and node.input[2] == input_name:
             new_input_name = node.input[2] + '_nc_split_' + node.name
             new_input = helper.make_tensor(
                 new_input_name,
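Note: the bias input of ONNX Conv is optional (input index 2), so node.input may hold only two names; the added length check keeps split_shared_bias from raising an IndexError on bias-free convolutions. A runnable check using real onnx protos:

from onnx import helper

conv_no_bias = helper.make_node('Conv', inputs=['x', 'w'], outputs=['y'])
conv_with_bias = helper.make_node('Conv', inputs=['x', 'w', 'b'], outputs=['y2'])

for node in (conv_no_bias, conv_with_bias):
    # Without the len() guard, the bias-free Conv raises an IndexError here.
    if len(node.input) > 2 and node.input[2] == 'b':
        print(node.output[0], 'consumes the shared bias')  # prints only 'y2'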
