Commit cb58079 (1 parent: c1485bf)

add UTs for e2e pt2e

Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 file changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import os

import pytest
import torch

from neural_compressor.common.utils import logger
from neural_compressor.torch.quantization import (
    convert,
    get_default_pt2e_static_config,
    prepare,
    quantize,
)
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class TestPT2EQuantization:

    @staticmethod
    def get_toy_model():
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                x = a / (torch.abs(a) + 1)
                if b.sum() < 0:
                    b = b * -1
                return x * b

        inp1 = torch.randn(10)
        inp2 = torch.randn(10)
        example_inputs = (inp1, inp2)
        bar = Bar()
        return bar, example_inputs

    @staticmethod
    def build_simple_torch_model_and_example_inputs():
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 20)
                self.fc2 = torch.nn.Linear(20, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.fc2(x)
                return x

        model = SimpleModel()
        example_inputs = (torch.randn(10, 10),)
        return model, example_inputs

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_quantize_simple_model(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()

        def calib_fn(model):
            for _ in range(2):
                model(*example_inputs)

        quant_config = get_default_pt2e_static_config()
        q_model = quantize(model=model, quant_config=quant_config, example_inputs=example_inputs, run_fn=calib_fn)

        # Enable Inductor weight freezing so the quantized weights can be constant-folded at compile time.
        from torch._inductor import config

        config.freezing = True
        opt_model = torch.compile(q_model)
        out = opt_model(*example_inputs)
        logger.warning("out shape is %s", out.shape)
        assert out is not None

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_prepare_and_convert_on_simple_model(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()

        def calib_fn(model):
            for _ in range(2):
                model(*example_inputs)

        quant_config = get_default_pt2e_static_config()

        prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
        calib_fn(prepared_model)
        q_model = convert(prepared_model)
        assert q_model is not None, "Quantization failed!"

        from torch._inductor import config

        config.freezing = True
        opt_model = torch.compile(q_model)
        out = opt_model(*example_inputs)
        logger.warning("out shape is %s", out.shape)
        assert out is not None

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_prepare_and_convert_on_llm(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Silence tokenizer fork-parallelism warnings during the test run.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        model_name = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
        example_inputs = (input_ids,)
        quant_config = get_default_pt2e_static_config()
        # prepare
        prepared_model = prepare(model, quant_config, example_inputs=example_inputs)
        # calibrate
        for _ in range(2):
            prepared_model(*example_inputs)
        # convert
        converted_model = convert(prepared_model)
        # inference
        from torch._inductor import config

        config.freezing = True
        opt_model = torch.compile(converted_model)
        out = opt_model(*example_inputs)
        assert out.logits is not None
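
For reference, the end-to-end flow these tests exercise can be distilled into a short standalone sketch. This is a minimal example, not part of the commit: it assumes torch>=2.3.0 (per the skipif markers above) and reuses the same prepare/convert calls as the tests; the Sequential model here is a stand-in for the tests' SimpleModel.

import torch
from neural_compressor.torch.quantization import convert, get_default_pt2e_static_config, prepare

# Toy model and calibration input, standing in for build_simple_torch_model_and_example_inputs.
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 10))
example_inputs = (torch.randn(10, 10),)

# prepare -> calibrate -> convert, mirroring test_prepare_and_convert_on_simple_model.
prepared = prepare(model, quant_config=get_default_pt2e_static_config(), example_inputs=example_inputs)
for _ in range(2):
    prepared(*example_inputs)
q_model = convert(prepared)

# As in the tests: turn on Inductor weight freezing, compile, and run.
from torch._inductor import config
config.freezing = True
out = torch.compile(q_model)(*example_inputs)

Setting config.freezing = True before torch.compile lets Inductor treat model weights as constants and fold the weight dequantization at compile time, which is presumably why every test flips it on before compiling the converted model.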
