Skip to content

Commit 12f2ef2

Browse files
committed
add MSE metric support for pytorch backend
1 parent a61c7d4 commit 12f2ef2

File tree

4 files changed

+229
-10
lines changed

4 files changed

+229
-10
lines changed

lpot/adaptor/pytorch.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ def __init__(self, framework_specific_info):
491491
else:
492492
assert False, "Unsupport quantization approach: {}".format(self.approach)
493493

494+
self.fp32_results = []
495+
self.fp32_preds_as_label = False
496+
497+
494498
def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
495499
"""This is a helper function for `query_fw_capability`,
496500
and it will get all quantizable ops from model.
@@ -734,7 +738,13 @@ def evaluate(self, model, dataloader, postprocess=None,
734738
if self.is_baseline:
735739
model_.to("dpcpp")
736740

741+
if metric and hasattr(metric, "compare_label") and not metric.compare_label:
742+
self.fp32_preds_as_label = True
743+
results = []
744+
737745
with torch.no_grad():
746+
if metric:
747+
metric.reset()
738748
for idx, (input, label) in enumerate(dataloader):
739749
if measurer is not None:
740750
measurer.start()
@@ -760,10 +770,24 @@ def evaluate(self, model, dataloader, postprocess=None,
760770
measurer.end()
761771
if postprocess is not None:
762772
output, label = postprocess((output, label))
763-
if metric is not None:
773+
if metric is not None and not self.fp32_preds_as_label:
764774
metric.update(output, label)
775+
if self.fp32_preds_as_label:
776+
self.fp32_results.append(output) if fp32_baseline else \
777+
results.append(output)
765778
if idx + 1 == iteration:
766779
break
780+
781+
if self.fp32_preds_as_label:
782+
from .torch_utils.util import collate_torch_preds
783+
if fp32_baseline:
784+
results = collate_torch_preds(self.fp32_results)
785+
metric.update(results, results)
786+
else:
787+
reference = collate_torch_preds(self.fp32_results)
788+
results = collate_torch_preds(results)
789+
metric.update(results, reference)
790+
767791
acc = metric.result() if metric is not None else 0
768792

769793
if tensorboard:

lpot/adaptor/torch_utils/util.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import numpy as np
2+
import torch
3+
4+
def collate_torch_preds(results):
5+
batch = results[0]
6+
if isinstance(batch, list):
7+
results = zip(*results)
8+
collate_results = []
9+
for output in results:
10+
output = [batch.numpy() if isinstance(batch, torch.Tensor) else batch for batch in output]
11+
collate_results.append(np.concatenate(output))
12+
elif isinstance(batch, torch.Tensor):
13+
results = [batch.numpy() if isinstance(batch, torch.Tensor) else batch for batch in results]
14+
collate_results = np.concatenate(results)
15+
return collate_results

lpot/experimental/metric/metric.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ def __init__(self):
4343
torch_ignite.metrics.Accuracy),
4444
"Loss": WrapPyTorchMetric(
4545
PyTorchLoss),
46-
"MAE": WrapPyTorchMetric(
47-
torch_ignite.metrics.MeanAbsoluteError),
48-
"RMSE": WrapPyTorchMetric(
49-
torch_ignite.metrics.RootMeanSquaredError),
50-
"MSE": WrapPyTorchMetric(
51-
torch_ignite.metrics.MeanSquaredError),
5246
}
5347
self.metrics.update(PYTORCH_METRICS)
5448

@@ -449,7 +443,7 @@ def reset(self):
449443
def result(self):
450444
return self.sum / self.sample
451445

452-
@metric_registry('MAE', 'tensorflow, onnxrt_qlinearops, onnxrt_integerops')
446+
@metric_registry('MAE', 'tensorflow, pytorch, onnxrt_qlinearops, onnxrt_integerops')
453447
class MAE(BaseMetric):
454448
def __init__(self, compare_label=True):
455449
self.label_list = []
@@ -472,7 +466,7 @@ def result(self):
472466
assert aes_size, "predictions shouldn't be none"
473467
return aes_sum / aes_size
474468

475-
@metric_registry('RMSE', 'tensorflow, mxnet, onnxrt_qlinearops, onnxrt_integerops')
469+
@metric_registry('RMSE', 'tensorflow, pytorch, mxnet, onnxrt_qlinearops, onnxrt_integerops')
476470
class RMSE(BaseMetric):
477471
def __init__(self, compare_label=True):
478472
self.mse = MSE(compare_label)
@@ -486,7 +480,7 @@ def reset(self):
486480
def result(self):
487481
return np.sqrt(self.mse.result())
488482

489-
@metric_registry('MSE', 'tensorflow, onnxrt_qlinearops, onnxrt_integerops')
483+
@metric_registry('MSE', 'tensorflow, pytorch, onnxrt_qlinearops, onnxrt_integerops')
490484
class MSE(BaseMetric):
491485
def __init__(self, compare_label=True):
492486
self.label_list = []

test/test_mse_metric.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
2+
import torch
3+
import torch.nn.quantized as nnq
4+
from torch.quantization import QuantStub, DeQuantStub
5+
import torchvision
6+
import unittest
7+
import os
8+
from lpot.adaptor import FRAMEWORKS
9+
from lpot.model import MODELS
10+
import lpot.adaptor.pytorch as lpot_torch
11+
from lpot.experimental import Quantization, common
12+
import shutil
13+
import copy
14+
import numpy as np
15+
16+
try:
17+
import intel_pytorch_extension as ipex
18+
TEST_IPEX = True
19+
except:
20+
TEST_IPEX = False
21+
22+
torch.manual_seed(1)
23+
24+
def build_ptq_yaml():
25+
fake_yaml = '''
26+
model:
27+
name: imagenet
28+
framework: pytorch
29+
30+
evaluation:
31+
accuracy:
32+
metric:
33+
MSE:
34+
compare_label: False
35+
performance:
36+
warmup: 5
37+
iteration: 10
38+
39+
tuning:
40+
accuracy_criterion:
41+
absolute: 100.0
42+
higher_is_better: False
43+
exit_policy:
44+
timeout: 0
45+
random_seed: 9527
46+
workspace:
47+
path: saved
48+
'''
49+
with open('ptq_yaml.yaml', 'w', encoding="utf-8") as f:
50+
f.write(fake_yaml)
51+
52+
53+
def build_dynamic_yaml():
54+
fake_yaml = '''
55+
model:
56+
name: imagenet
57+
framework: pytorch
58+
59+
quantization:
60+
approach: post_training_dynamic_quant
61+
evaluation:
62+
accuracy:
63+
metric:
64+
MSE:
65+
compare_label: False
66+
performance:
67+
warmup: 5
68+
iteration: 10
69+
70+
tuning:
71+
accuracy_criterion:
72+
absolute: 100.0
73+
higher_is_better: False
74+
exit_policy:
75+
timeout: 0
76+
random_seed: 9527
77+
workspace:
78+
path: saved
79+
'''
80+
with open('dynamic_yaml.yaml', 'w', encoding="utf-8") as f:
81+
f.write(fake_yaml)
82+
83+
84+
def build_ipex_yaml():
85+
fake_yaml = '''
86+
model:
87+
name: imagenet
88+
framework: pytorch_ipex
89+
90+
evaluation:
91+
accuracy:
92+
metric:
93+
MSE:
94+
compare_label: False
95+
performance:
96+
warmup: 5
97+
iteration: 10
98+
99+
tuning:
100+
accuracy_criterion:
101+
relative: 0.01
102+
exit_policy:
103+
timeout: 0
104+
random_seed: 9527
105+
workspace:
106+
path: saved
107+
'''
108+
with open('ipex_yaml.yaml', 'w', encoding="utf-8") as f:
109+
f.write(fake_yaml)
110+
111+
112+
@unittest.skipIf(TEST_IPEX, "TODO: Please wait to IPEX + PyTorch1.7 release")
113+
class TestPytorchAdaptor(unittest.TestCase):
114+
framework_specific_info = {"device": "cpu",
115+
"approach": "post_training_static_quant",
116+
"random_seed": 1234,
117+
"q_dataloader": None}
118+
framework = "pytorch"
119+
adaptor = FRAMEWORKS[framework](framework_specific_info)
120+
model = torchvision.models.quantization.resnet18()
121+
lpot_model = MODELS['pytorch'](model)
122+
123+
@classmethod
124+
def setUpClass(self):
125+
build_ptq_yaml()
126+
build_dynamic_yaml()
127+
128+
@classmethod
129+
def tearDownClass(self):
130+
os.remove('ptq_yaml.yaml')
131+
os.remove('dynamic_yaml.yaml')
132+
shutil.rmtree('./saved', ignore_errors=True)
133+
shutil.rmtree('runs', ignore_errors=True)
134+
135+
def test_quantization_saved(self):
136+
from lpot.utils.pytorch import load
137+
138+
for fake_yaml in ['dynamic_yaml.yaml', 'ptq_yaml.yaml']:
139+
if fake_yaml == 'dynamic_yaml.yaml':
140+
model = torchvision.models.quantization.resnet18()
141+
else:
142+
model = copy.deepcopy(self.model)
143+
if fake_yaml == 'ptq_yaml.yaml':
144+
model.eval().fuse_model()
145+
quantizer = Quantization(fake_yaml)
146+
dataset = quantizer.dataset('dummy', (100, 3, 256, 256), label=True)
147+
quantizer.model = common.Model(model)
148+
quantizer.calib_dataloader = common.DataLoader(dataset)
149+
quantizer.eval_dataloader = common.DataLoader(dataset)
150+
q_model = quantizer()
151+
self.assertTrue(bool(q_model))
152+
153+
@unittest.skipIf(not TEST_IPEX, "Unsupport Intel PyTorch Extension")
154+
class TestPytorchIPEXAdaptor(unittest.TestCase):
155+
@classmethod
156+
def setUpClass(self):
157+
build_ipex_yaml()
158+
159+
@classmethod
160+
def tearDownClass(self):
161+
os.remove('ipex_yaml.yaml')
162+
shutil.rmtree('./saved', ignore_errors=True)
163+
shutil.rmtree('runs', ignore_errors=True)
164+
def test_tuning_ipex(self):
165+
from lpot.experimental import Quantization
166+
model = torchvision.models.resnet18()
167+
quantizer = Quantization('ipex_yaml.yaml')
168+
dataset = quantizer.dataset('dummy', (100, 3, 256, 256), label=True)
169+
quantizer.model = common.Model(model)
170+
quantizer.calib_dataloader = common.DataLoader(dataset)
171+
quantizer.eval_dataloader = common.DataLoader(dataset)
172+
lpot_model = quantizer()
173+
lpot_model.save("./saved")
174+
try:
175+
script_model = torch.jit.script(model.to(ipex.DEVICE))
176+
except:
177+
script_model = torch.jit.trace(model.to(ipex.DEVICE), torch.randn(10, 3, 224, 224).to(ipex.DEVICE))
178+
from lpot.experimental import Benchmark
179+
evaluator = Benchmark('ipex_yaml.yaml')
180+
evaluator.model = common.Model(script_model)
181+
evaluator.b_dataloader = common.DataLoader(dataset)
182+
results = evaluator()
183+
184+
185+
if __name__ == "__main__":
186+
unittest.main()

0 commit comments

Comments
 (0)