
Commit bfce941

ClarkChin08 authored and ftian1 committed
[new] add mse higher is better to 1.2.1
1 parent 55da45e commit bfce941

File tree: 5 files changed, +60 -61 lines

lpot/adaptor/tensorflow.py

Lines changed: 23 additions & 30 deletions
@@ -62,6 +62,9 @@ def __init__(self, framework_specific_info):
             os.path.dirname(__file__), "tensorflow.yaml"))
         self.op_wise_sequences = self.query_handler.get_eightbit_patterns()
 
+        self.fp32_results = []
+        self.fp32_preds_as_label = False
+
     def log_histogram(self, writer, tag, values, step=0, bins=1000):
         import tensorflow as tf
         # Convert to a numpy array
@@ -177,9 +180,8 @@ def evaluate(self, model, dataloader, postprocess=None,
             outputs.extend(int8_inspect_node_name)
 
         if metric and hasattr(metric, "compare_label") and not metric.compare_label:
-            results = [[] for _ in range(len(outputs))]
-            if not os.path.exists(os.path.join(self.work_dir, "output_tensors")):
-                os.makedirs(os.path.join(self.work_dir, "output_tensors"))
+            self.fp32_preds_as_label = True
+            results = []
 
         origin_output_tensor_names = model.output_tensor_names
         model.output_tensor_names = outputs
@@ -206,13 +208,10 @@ def evaluate(self, model, dataloader, postprocess=None,
             else:
                 predictions = model.sess.run(output_tensor, feed_dict)
 
-            if metric and hasattr(metric, "compare_label") and not metric.compare_label:
-                if isinstance(predictions, list):
-                    result = [np.array(value) for value in predictions.values()]
-                else:
-                    result = [predictions]
-                for i in range(len(outputs)):
-                    results[i].append(result[i])
+            if self.fp32_preds_as_label:
+                self.fp32_results.append(predictions) if fp32_baseline else \
+                    results.append(predictions)
+
             # Inspect node output, just get 1st iteration output tensors for now
             if idx == 0 and tensorboard:
                 for index, node_name in enumerate(outputs):
@@ -229,28 +228,22 @@
                 predictions = predictions[:len(origin_output_tensor_names)]
             if postprocess is not None:
                 predictions, labels = postprocess((predictions, labels))
-            if metric is not None:
-                if not hasattr(metric, "compare_label"):
-                    metric.update(predictions, labels)
-                elif hasattr(metric, "compare_label") and metric.compare_label:
-                    metric.update(predictions, labels)
+            if metric is not None and not self.fp32_preds_as_label:
+                metric.update(predictions, labels)
             if idx + 1 == iteration:
                 break
-        if metric:
-            if hasattr(metric, "compare_label") and not metric.compare_label:
-                results = [np.array(result) for result in results]
-                metric.reset()
-                if fp32_baseline:
-                    np.savez(os.path.join(self.work_dir, "output_tensors", "fp32.npz"),
-                             *results)
-                    metric.update(results, results)
-                else:
-                    np.savez(os.path.join(self.work_dir, "output_tensors", "int8.npz"),
-                             *results)
-                    reference_file = np.load(os.path.join(self.work_dir, "output_tensors", \
-                                             "fp32.npz"), allow_pickle=True)
-                    reference = [reference_file[key] for key in reference_file]
-                    metric.update(reference, results)
+
+        if self.fp32_preds_as_label:
+            from .tf_utils.util import collate_tf_preds
+            metric.reset()
+            if fp32_baseline:
+                results = collate_tf_preds(self.fp32_results)
+                metric.update(results, results)
+            else:
+                reference = collate_tf_preds(self.fp32_results)
+                results = collate_tf_preds(results)
+                metric.update(results, reference)
+
         acc = metric.result() if metric is not None else 0
         if tensorboard:
             new_dir = temp_dir + "_acc_" + str(acc)
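Taken together, the tensorflow.py changes replace the old on-disk exchange (fp32.npz / int8.npz under work_dir/output_tensors) with an in-memory one: when the metric declares compare_label == False, per-batch predictions from the FP32 baseline run are appended to self.fp32_results, the quantized run appends to a local results list, and both are collated and passed to metric.update(results, reference) at the end of evaluate(). A rough sketch of that control flow (run_batch, dataloader, state and the metric are stand-ins for this illustration, not lpot APIs):

import numpy as np

def evaluate_sketch(run_batch, dataloader, metric, fp32_baseline, state):
    """Simplified shape of the new compare_label == False path in evaluate()."""
    results = []
    for inputs, _labels in dataloader:
        predictions = run_batch(inputs)                  # per-batch model outputs
        if fp32_baseline:
            state["fp32_results"].append(predictions)    # plays the role of self.fp32_results
        else:
            results.append(predictions)

    metric.reset()
    if fp32_baseline:
        collated = collate(state["fp32_results"])
        metric.update(collated, collated)                # baseline is scored against itself
    else:
        reference = collate(state["fp32_results"])       # FP32 outputs kept in memory, no fp32.npz
        metric.update(collate(results), reference)       # quantized outputs vs. FP32 reference
    return metric.result()

def collate(results):
    # Same idea as collate_tf_preds below: concatenate per-batch ndarrays.
    return np.concatenate(results)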

lpot/adaptor/tf_utils/util.py

Lines changed: 9 additions & 5 deletions
@@ -226,15 +226,19 @@ def iterator_sess_run(sess, iter_op, feed_dict, output_tensor, iteration=-1, mea
         except tf.errors.OutOfRangeError:
             break
 
-    def collate_fn(results):
+    preds = collate_tf_preds(preds)
+    return preds
+
+def collate_tf_preds(results):
+    batch = results[0]
+    if isinstance(batch, list):
         results = zip(*results)
         collate_results = []
         for output in results:
             collate_results.append(np.concatenate(output))
-        return collate_results
-
-    preds = collate_fn(preds)
-    return preds
+    elif isinstance(batch, np.ndarray):
+        collate_results = np.concatenate(results)
+    return collate_results
 
 def get_input_node_names(graph_def):
     g = GraphAnalyzer()
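The former nested collate_fn is promoted to a module-level collate_tf_preds so the adaptor can reuse it, and it now handles both shapes of per-batch predictions: a list of ndarrays (single output tensor) is concatenated directly, while a list of lists of ndarrays (multiple output tensors) is concatenated per output. A small usage sketch (the function body is copied from the diff above to keep the example self-contained):

import numpy as np

def collate_tf_preds(results):
    # Copy of the function added above, for illustration only.
    batch = results[0]
    if isinstance(batch, list):
        results = zip(*results)
        collate_results = []
        for output in results:
            collate_results.append(np.concatenate(output))
    elif isinstance(batch, np.ndarray):
        collate_results = np.concatenate(results)
    return collate_results

# Single output tensor: list of per-batch ndarrays -> one concatenated ndarray.
single = [np.array([1, 2]), np.array([3, 4])]
print(collate_tf_preds(single))                        # [1 2 3 4]

# Multiple output tensors: per-batch [out0, out1] lists -> one ndarray per output.
multi = [[np.array([1, 2]), np.array([9])], [np.array([3, 4]), np.array([8])]]
print([a.tolist() for a in collate_tf_preds(multi)])   # [[1, 2, 3, 4], [9, 8]]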

lpot/conf/config.py

Lines changed: 2 additions & 1 deletion
@@ -446,7 +446,7 @@ def percent_to_float(data):
         },
         Optional('tuning', default={
             'strategy': {'name': 'basic'},
-            'accuracy_criterion': {'relative': 0.01},
+            'accuracy_criterion': {'relative': 0.01, 'higher_is_better': True},
             'objective': 'performance',
             'exit_policy': {'timeout': 0, 'max_trials': 100, 'performance_only': False},
             'random_seed': 1978, 'tensorboard': False,
@@ -460,6 +460,7 @@ def percent_to_float(data):
         Optional('accuracy_criterion', default={'relative': 0.01}): {
             Optional('relative'): And(Or(str, float), Use(percent_to_float)),
             Optional('absolute'): And(Or(str, float), Use(percent_to_float)),
+            Optional('higher_is_better', default=True): bool,
         },
         Optional('objective', default='performance'): And(str, lambda s: s in OBJECTIVES),
         Optional('exit_policy', default={'timeout': 0,
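With the schema change, higher_is_better can be set next to relative/absolute under accuracy_criterion and defaults to True. An illustrative tuning fragment in the dict form the Schema validates (the field values here are made up for the example, not taken from the commit):

# Illustrative tuning config fragment as validated by lpot/conf/config.py.
tuning = {
    'strategy': {'name': 'basic'},
    # For an error metric such as MSE, a lower value is better, so the
    # direction of the accuracy check can now be flipped explicitly.
    'accuracy_criterion': {'relative': 0.01, 'higher_is_better': False},
    'objective': 'performance',
    'exit_policy': {'timeout': 0, 'max_trials': 100, 'performance_only': False},
    'random_seed': 1978,
    'tensorboard': False,
}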

lpot/experimental/metric/metric.py

Lines changed: 4 additions & 11 deletions
@@ -115,7 +115,6 @@ def register(self, name, metric_cls):
         assert name not in self.metrics.keys(), 'registered metric name already exists.'
         self.metrics.update({name: metric_cls})
 
-
 def metric_registry(metric_type, framework):
     """The class decorator used to register all Metric subclasses.
        cross framework metric is supported by add param as framework='tensorflow, \
@@ -263,8 +262,8 @@ def _topk_shape_validate(preds, labels):
     return preds, labels
 
 def _shape_validate(preds, labels):
-    assert type(preds) in [int, list], 'preds must be in int or list'
-    assert type(labels) in [int, list], 'labels must be in int or list'
+    assert type(preds) in [int, list, np.ndarray], 'preds must be in int or list, ndarray'
+    assert type(labels) in [int, list, np.ndarray], 'labels must be in int or list, ndarray'
     if isinstance(preds, int):
         preds = [np.array([preds])]
     elif isinstance(preds[0], int):
@@ -471,10 +470,7 @@ def result(self):
         aes_sum = sum([np.sum(ae) for ae in aes])
         aes_size = sum([ae.size for ae in aes])
         assert aes_size, "predictions shouldn't be none"
-        if self.compare_label:
-            return aes_sum / aes_size
-        else:
-            return 1 / (aes_sum / aes_size + 0.001)
+        return aes_sum / aes_size
 
 @metric_registry('RMSE', 'tensorflow, mxnet, onnxrt_qlinearops, onnxrt_integerops')
 class RMSE(BaseMetric):
@@ -511,10 +507,7 @@ def result(self):
         squares_sum = sum([np.sum(square) for square in squares])
         squares_size = sum([square.size for square in squares])
         assert squares_size, "predictions should't be None"
-        if self.compare_label:
-            return squares_sum / squares_size
-        else:
-            return 1 / (squares_sum / squares_size + 0.001)
+        return squares_sum / squares_size
 
 @metric_registry('topk', 'tensorflow')
 class TensorflowTopK(BaseMetric):
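The result() methods above no longer invert the error with 1 / (error + 0.001) when compare_label is False; the raw mean error is always returned, and the comparison direction is now carried by higher_is_better in the tuning config instead of being baked into the metric. A small numeric illustration of the simplified computation (stand-alone, not the metric class itself):

import numpy as np

# Per-batch squared errors, e.g. from comparing INT8 outputs against the FP32 baseline.
squares = [np.array([0.0, 1.0]), np.array([4.0])]

squares_sum = sum(np.sum(s) for s in squares)    # 5.0
squares_size = sum(s.size for s in squares)      # 3
mse = squares_sum / squares_size                 # 1.666...

# Before this commit, a compare_label == False metric returned 1 / (mse + 0.001)
# so that "higher is better" always held; now the raw error is returned.
print(mse)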

lpot/objective.py

Lines changed: 22 additions & 14 deletions
@@ -138,15 +138,19 @@ class Objective(object):
     """
 
     def __init__(self, accuracy_criterion, is_measure=False):
-        assert isinstance(
-            accuracy_criterion,
-            dict) and len(accuracy_criterion) == 1
-        k, v = list(accuracy_criterion.items())[0]
-        assert k in ['relative', 'absolute']
-        assert float(v) < 1 and float(v) > -1
-
-        self.acc_goal = float(v)
-        self.relative = True if k == 'relative' else False
+
+        assert isinstance(accuracy_criterion, dict), 'accuracy criterian should be dict'
+        assert 'relative' in accuracy_criterion or 'absolute' in accuracy_criterion, \
+            'accuracy criterion should set relative or absolute'
+        self.higher_is_better = True
+        for k, v in accuracy_criterion.items():
+            if k in ['relative', 'absolute']:
+                assert float(v) < 1 and float(v) > -1
+                self.relative = True if k == 'relative' else False
+                self.acc_goal = float(v)
+            elif k == 'higher_is_better':
+                self.higher_is_better = bool(v)
+
         self.baseline = None
         self.val = None
         self.is_measure = is_measure
@@ -158,7 +162,6 @@ def compare(self, last, baseline):
 
         Args:
             last (tuple): The tuple of last metric.
-            accuracy_criterion (float): The allowed accuracy absolute loss.
            baseline (tuple): The tuple saving FP32 baseline.
         """
         acc, perf = self.val
@@ -169,11 +172,16 @@ def compare(self, last, baseline):
             last_measure = 0
 
         base_acc, _ = baseline
+
+        if self.relative:
+            acc_target = base_acc * (1 - float(self.acc_goal)) if self.higher_is_better \
+                else base_acc * (1 + float(self.acc_goal))
+        else:
+            acc_target = base_acc - float(self.acc_goal) if self.higher_is_better \
+                else base_acc + float(self.acc_goal)
 
-        acc_target = base_acc - float(self.acc_goal) if not self.relative \
-            else base_acc * (1 - float(self.acc_goal))
-        if acc >= acc_target and (last_measure == 0 or perf < last_measure):
-            return True
+        if last_measure == 0 or perf < last_measure:
+            return acc < acc_target if self.higher_is_better else acc >= acc_target
         else:
             return False
 
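The rewritten compare() derives the acceptable accuracy target from the baseline in four ways, depending on relative/absolute and higher_is_better. A worked sketch of just the target computation, mirroring the added code (the numbers are illustrative):

def acc_target(base_acc, acc_goal, relative, higher_is_better):
    """Mirror of the target computation added to Objective.compare()."""
    if relative:
        return base_acc * (1 - acc_goal) if higher_is_better \
            else base_acc * (1 + acc_goal)
    return base_acc - acc_goal if higher_is_better \
        else base_acc + acc_goal

# With a 0.76 baseline accuracy and a 1% relative criterion, the tuned model
# may drop to 0.7524 when higher is better.
print(acc_target(0.76, 0.01, relative=True, higher_is_better=True))     # 0.7524

# With an MSE baseline of 0.02 and an absolute criterion of 0.005 where lower
# is better, the tuned model may rise to 0.025.
print(acc_target(0.02, 0.005, relative=False, higher_is_better=False))  # 0.025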
