
Commit 00d6e55

ClarkChin08 authored and ftian1 committed
[fix] fix wide and deep and add new ONNX MSE metric (higher is better)
1 parent f06b6e7 commit 00d6e55

File tree

5 files changed: +69 additions, -38 deletions

lpot/adaptor/onnxrt.py

Lines changed: 26 additions & 31 deletions
@@ -52,6 +52,9 @@ def __init__(self, framework_specific_info):
         self.quantizable_op_types = self._query_quantizable_op_types()
         self.evaluate_nums = 0

+        self.fp32_results = []
+        self.fp32_preds_as_label = False
+
     @dump_elapsed_time("Pass quantize model")
     def quantize(self, tune_cfg, model, dataLoader, q_func=None):
         """The function is used to do calibration and quanitization in post-training
@@ -286,12 +289,13 @@ def evaluate(self, input_graph, dataloader, postprocess=None,
         """
         session = ort.InferenceSession(input_graph.model.SerializeToString(), None)
         len_outputs = len(session.get_outputs())
+
         if metric:
-            if hasattr(metric, "compare_label"):
-                if not metric.compare_label:
-                    results = [[] for _ in range(len_outputs)]
-                    if not os.path.exists(os.path.join(self.work_space, "output_tensors")):
-                        os.makedirs(os.path.join(self.work_space, "output_tensors"))
+            metric.reset()
+            if hasattr(metric, "compare_label") and not metric.compare_label:
+                self.fp32_preds_as_label = True
+                results = []
+
         ort_inputs = {}
         len_inputs = len(session.get_inputs())
         inputs_names = [session.get_inputs()[i].name for i in range(len_inputs)]
@@ -311,37 +315,28 @@ def evaluate(self, input_graph, dataloader, postprocess=None,
             for i in range(len_inputs):
                 ort_inputs.update({inputs_names[i]: batch[i]})
             predictions = session.run(None, ort_inputs)
-            if metric:
-                if hasattr(metric, "compare_label"):
-                    if not metric.compare_label:
-                        for i in range(len_outputs):
-                            results[i].append(predictions[i])
+
+            if self.fp32_preds_as_label:
+                self.fp32_results.append(predictions) if fp32_baseline else \
+                    results.append(predictions)

             if postprocess is not None:
                 predictions, labels = postprocess((predictions, labels))
-            if metric is not None:
-                if not hasattr(metric, "compare_label"):
-                    metric.update(predictions, labels)
-                elif hasattr(metric, "compare_label") and metric.compare_label:
-                    metric.update(predictions, labels)
+            if metric is not None and not self.fp32_preds_as_label:
+                metric.update(predictions, labels)
             if idx + 1 == iteration:
                 break
-        if metric:
-            if hasattr(metric, "compare_label"):
-                if not metric.compare_label:
-                    metric.reset()
-                    results = [np.array(result) for result in results]
-                    if fp32_baseline:
-                        np.savez(os.path.join(self.work_space,"output_tensors", "fp32.npz"),
-                                 *results)
-                        metric.update(results, results)
-                    else:
-                        np.savez(os.path.join(self.work_space,"output_tensors", "int8.npz"),
-                                 *results)
-                        reference_file = np.load(os.path.join(self.work_space, "output_tensors", \
-                                                 "fp32.npz"), allow_pickle=True)
-                        reference = [reference_file[key] for key in reference_file]
-                        metric.update(reference, results)
+
+        if self.fp32_preds_as_label:
+            from .ox_utils.util import collate_preds
+            if fp32_baseline:
+                results = collate_preds(self.fp32_results)
+                metric.update(results, results)
+            else:
+                reference = collate_preds(self.fp32_results)
+                results = collate_preds(results)
+                metric.update(results, reference)
+
         acc = metric.result() if metric is not None else 0
         return acc
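
The new evaluate() flow only activates for metrics that declare compare_label = False, i.e. metrics that score quantized outputs against the cached FP32 outputs instead of against dataset labels (the "MSE, higher is better" case from the commit message). The adaptor relies only on reset(), update(preds, labels), result(), and the compare_label attribute, so a minimal sketch of such a metric could look like the following; the class name and the negated-MSE convention are illustrative assumptions, not part of this commit:

import numpy as np

class FP32MatchMSE:
    """Hypothetical metric sketch: compares quantized outputs with cached FP32
    outputs, so no ground-truth labels are involved."""
    compare_label = False    # tells the adaptor to feed the FP32 predictions in as 'labels'

    def __init__(self):
        self._sum, self._count = 0.0, 0

    def reset(self):
        self._sum, self._count = 0.0, 0

    def update(self, preds, labels):
        # preds/labels arrive in the shape collate_preds() produces:
        # a single ndarray or a list of ndarrays (one per model output)
        preds = preds if isinstance(preds, list) else [preds]
        labels = labels if isinstance(labels, list) else [labels]
        for p, l in zip(preds, labels):
            self._sum += float(np.mean((np.asarray(p) - np.asarray(l)) ** 2))
            self._count += 1

    def result(self):
        # negated so that a larger value means a closer match ("higher is better")
        return -self._sum / self._count if self._count else 0.0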

lpot/adaptor/ox_utils/util.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import numpy as np
+
+def collate_preds(results):
+    batch = results[0]
+    if isinstance(batch, list):
+        results = zip(*results)
+        collate_results = []
+        for output in results:
+            collate_results.append(np.concatenate(output))
+    elif isinstance(batch, np.ndarray):
+        collate_results = np.concatenate(results)
+    return collate_results
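
collate_preds() flattens a list of per-batch predictions into one ndarray, or into a list of ndarrays when the model has several outputs, which is the shape evaluate() feeds to metric.update(). A small usage sketch (the batch shapes and the import path from the repo root are the only assumptions):

import numpy as np
from lpot.adaptor.ox_utils.util import collate_preds

# two batches from a two-output model, as returned by session.run(None, ort_inputs)
batch1 = [np.zeros((4, 10)), np.zeros((4, 1))]
batch2 = [np.ones((4, 10)), np.ones((4, 1))]
outs = collate_preds([batch1, batch2])
print([o.shape for o in outs])          # [(8, 10), (8, 1)]

# single-output models yield a plain ndarray per batch
print(collate_preds([np.zeros((4, 10)), np.ones((4, 10))]).shape)  # (8, 10)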

lpot/adaptor/tensorflow.py

Lines changed: 5 additions & 4 deletions
@@ -179,9 +179,11 @@ def evaluate(self, model, dataloader, postprocess=None,
             output_postfix = "_int8.output"
             outputs.extend(int8_inspect_node_name)

-        if metric and hasattr(metric, "compare_label") and not metric.compare_label:
-            self.fp32_preds_as_label = True
-            results = []
+        if metric:
+            metric.reset()
+            if hasattr(metric, "compare_label") and not metric.compare_label:
+                self.fp32_preds_as_label = True
+                results = []

         origin_output_tensor_names = model.output_tensor_names
         model.output_tensor_names = outputs
@@ -235,7 +237,6 @@ def evaluate(self, model, dataloader, postprocess=None,

         if self.fp32_preds_as_label:
             from .tf_utils.util import collate_tf_preds
-            metric.reset()
             if fp32_baseline:
                 results = collate_tf_preds(self.fp32_results)
                 metric.update(results, results)
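
On the TensorFlow side this is the same ordering fix: metric.reset() now runs once at the top of every evaluate() call instead of inside the fp32_preds_as_label branch at the end, so state accumulated during the FP32 baseline pass cannot leak into the evaluation of the quantized graph. The calling pattern the fix assumes is roughly the sketch below (the real driver lives in the tuning strategy; argument names beyond those visible in the diff are assumptions):

# baseline pass caches FP32 predictions in adaptor.fp32_results
fp32_acc = adaptor.evaluate(fp32_model, dataloader, metric=metric, fp32_baseline=True)
# quantized pass starts from a freshly reset metric and is scored against that cache
int8_acc = adaptor.evaluate(int8_model, dataloader, metric=metric, fp32_baseline=False)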

lpot/model/model.py

Lines changed: 5 additions & 1 deletion
@@ -329,7 +329,7 @@ def __init__(self, model, framework_specific_info={}, **kwargs):
         self.workspace_path = deep_get(framework_specific_info, 'workspace_path', './')
         self.kwargs = copy.deepcopy(kwargs)
         kwargs.update({'name': deep_get(framework_specific_info, 'name')})
-        self.model = model
+        # self.model = model
         self.sess, self._input_tensor_names, self._output_tensor_names = \
             create_session_with_input_output(
                 model, input_tensor_names, output_tensor_names, **kwargs)
@@ -342,6 +342,10 @@ def __init__(self, model, framework_specific_info={}, **kwargs):
         if 'MakeIterator' in [node.op for node in self.sess.graph.as_graph_def().node]:
             self.iter_op = self.sess.graph.get_operation_by_name('MakeIterator')

+    @property
+    def model(self):
+        return self.sess.graph
+
     @property
     def graph_def(self):
         return self.sess.graph.as_graph_def()
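
In model.py the cached self.model attribute is replaced by a read-only property, so anything that reads .model gets the graph owned by the session that create_session_with_input_output() actually built, rather than whatever object was passed into the constructor. A minimal, framework-free sketch of the same pattern (all names here are illustrative, not the project's API):

class SessionLike:
    """Stand-in for the session returned by create_session_with_input_output."""
    def __init__(self, graph):
        self.graph = graph

class TFModelSketch:
    def __init__(self, model):
        # the constructor argument is only used to build the session; it is not cached
        self.sess = SessionLike(graph=f"graph built from {model!r}")

    @property
    def model(self):
        # always reflects the live session's graph
        return self.sess.graph

m = TFModelSketch("frozen_model.pb")
print(m.model)   # graph built from 'frozen_model.pb'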

lpot/quantization.py

Lines changed: 2 additions & 2 deletions
@@ -155,9 +155,9 @@ def eval_func(model):
         elif eval_dataloader is not None:
             self.exp_quantizer.eval_dataloader = eval_dataloader

-        if self.exp_quantizer.framework == 'tensorflow':
-            return self.exp_quantizer().graph
         lpot_model = self.exp_quantizer()
+        if self.exp_quantizer.framework == 'tensorflow':
+            return lpot_model.graph if lpot_model else None
         if self.exp_quantizer.framework == 'pytorch':
             saved_path = os.path.abspath(os.path.join(os.path.expanduser(
                 self.exp_quantizer.conf.usr_cfg.tuning.workspace.path), 'checkpoint'))
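
The quantization.py change makes the TensorFlow branch invoke the experimental quantizer exactly once and check its result before dereferencing .graph; previously a tuning run that produced no acceptable model (the quantizer returning None) raised AttributeError instead of handing None back to the caller. A hedged caller sketch, with the config path and model argument as placeholders and the call signature assumed from this wrapper:

from lpot.quantization import Quantization

quantizer = Quantization('conf.yaml')              # placeholder YAML config
q_graph = quantizer('./fp32_frozen_model.pb')      # tf.Graph on success, None if tuning failed
if q_graph is None:
    raise RuntimeError("quantization did not reach the accuracy goal; no model produced")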

0 commit comments
