
Commit 5f3f388

Authored by Kaihui-intel, with co-authors pre-commit-ci[bot] and yiliu30
Add prepare/convert ut for autoround (#1768)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Liu <106061964+yiliu30@users.noreply.github.com>
1 parent 76b4069 · commit 5f3f388

File tree: 4 files changed, +77 additions, -35 deletions


neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def __init__(
             scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
                 have different choices.
         """
-
+        super().__init__(weight_config)
         self.tokenizer = None
         self.weight_config = weight_config
         self.enable_full_range = enable_full_range
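
The fix: AutoRoundQuantizer.__init__ now forwards weight_config to its base class, replacing the stray blank line where that call belonged. A minimal sketch of the pattern, using a simplified stand-in for the base Quantizer class rather than neural_compressor's real API:

class Quantizer:
    """Simplified stand-in for the base class (an illustration, not the real API)."""

    def __init__(self, quant_config=None):
        # The base class keeps the config so shared prepare/convert machinery can read it later.
        self.quant_config = quant_config


class AutoRoundQuantizer(Quantizer):
    def __init__(self, weight_config=None):
        super().__init__(weight_config)  # the call this commit adds
        self.tokenizer = None
        self.weight_config = weight_config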

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 28 additions & 22 deletions
@@ -369,29 +369,35 @@ def autoround_quantize_entry(
     scale_dtype = quant_config.scale_dtype

     kwargs.pop("example_inputs")
-
-    quantizer = AutoRoundQuantizer(
-        weight_config=weight_config,
-        enable_full_range=enable_full_range,
-        batch_size=batch_size,
-        lr_scheduler=lr_scheduler,
-        use_quant_input=use_quant_input,
-        enable_minmax_tuning=enable_minmax_tuning,
-        lr=lr,
-        minmax_lr=minmax_lr,
-        low_gpu_mem_usage=low_gpu_mem_usage,
-        iters=iters,
-        seqlen=seqlen,
-        n_samples=n_samples,
-        sampler=sampler,
-        seed=seed,
-        n_blocks=n_blocks,
-        gradient_accumulate_steps=gradient_accumulate_steps,
-        not_use_best_mse=not_use_best_mse,
-        dynamic_max_gap=dynamic_max_gap,
-        scale_dtype=scale_dtype,
-    )
+    if getattr(model, "quantizer", False):
+        quantizer = model.quantizer
+    else:
+        quantizer = AutoRoundQuantizer(
+            weight_config=weight_config,
+            enable_full_range=enable_full_range,
+            batch_size=batch_size,
+            lr_scheduler=lr_scheduler,
+            use_quant_input=use_quant_input,
+            enable_minmax_tuning=enable_minmax_tuning,
+            lr=lr,
+            minmax_lr=minmax_lr,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            iters=iters,
+            seqlen=seqlen,
+            n_samples=n_samples,
+            sampler=sampler,
+            seed=seed,
+            n_blocks=n_blocks,
+            gradient_accumulate_steps=gradient_accumulate_steps,
+            not_use_best_mse=not_use_best_mse,
+            dynamic_max_gap=dynamic_max_gap,
+            scale_dtype=scale_dtype,
+        )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        del model.quantizer
+    else:
+        model.quantizer = quantizer
     logger.info("AutoRound quantization done.")
     return model
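
What changed: the entry now caches the quantizer on the model across calls. On a first (prepare) pass no quantizer attribute exists, so a fresh AutoRoundQuantizer is built and, after execute, stashed on the returned model; on a second (convert) pass the cached instance is reused and then removed. A runnable sketch of that caching pattern, with simplified stand-ins (Quantizer, Mode, and DummyModel here are illustrations, not the real neural_compressor APIs):

from enum import Enum


class Mode(Enum):
    PREPARE = "prepare"
    CONVERT = "convert"


class Quantizer:
    """Stand-in: the real class does the actual quantization work in execute()."""

    def execute(self, model, mode):
        return model


def quantize_entry(model, mode):
    # Reuse the quantizer cached by an earlier prepare pass, if any.
    if getattr(model, "quantizer", False):
        quantizer = model.quantizer
    else:
        quantizer = Quantizer()
    model = quantizer.execute(model, mode=mode)
    if getattr(model, "quantizer", False):
        del model.quantizer  # second (convert) pass: drop the cached instance
    else:
        model.quantizer = quantizer  # first (prepare) pass: cache it for convert
    return model


class DummyModel:
    pass


m = quantize_entry(DummyModel(), Mode.PREPARE)  # caches m.quantizer
m = quantize_entry(m, Mode.CONVERT)  # reuses the cached quantizer, then deletes it

This split is what lets prepare and convert run as two separate entry calls while sharing one quantizer, which the new unit test below exercises.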

neural_compressor/torch/quantization/config.py

Lines changed: 1 addition & 1 deletion
@@ -685,7 +685,7 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        scale_dtype: str = "fp16",
+        scale_dtype: str = "fp32",
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.

test/3x/torch/quantization/weight_only/test_autoround.py

Lines changed: 47 additions & 11 deletions
@@ -1,9 +1,17 @@
+import copy
+
 import pytest
 import torch
 import transformers

 from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
-from neural_compressor.torch.quantization import AutoRoundConfig, quantize
+from neural_compressor.torch.quantization import (
+    AutoRoundConfig,
+    convert,
+    get_default_AutoRound_config,
+    prepare,
+    quantize,
+)
 from neural_compressor.torch.utils import logger

 try:
@@ -14,8 +22,7 @@
     auto_round_installed = False


-@pytest.fixture(scope="module")
-def gpt_j():
+def get_gpt_j():
     tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
         "hf-internal-testing/tiny-random-GPTJForCausalLM",
         torchscript=True,
@@ -25,17 +32,15 @@ def gpt_j():

 @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
 class TestAutoRound:
-    @staticmethod
-    @pytest.fixture(scope="class", autouse=True)
-    def gpt_j_model(gpt_j):
-        yield gpt_j
+    def setup_class(self):
+        self.gptj = get_gpt_j()

     def setup_method(self, method):
         logger.info(f"Running TestAutoRound test: {method.__name__}")

-    def test_autoround(self, gpt_j_model):
+    def test_autoround(self):
         inp = torch.ones([1, 10], dtype=torch.long)
-
+        gpt_j_model = copy.deepcopy(self.gptj)
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
@@ -73,9 +78,9 @@ def test_autoround(self, gpt_j_model):
         assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]

-    def test_new_api(self, gpt_j_model):
+    def test_quantizer(self):
         inp = torch.ones([1, 10], dtype=torch.long)
-
+        gpt_j_model = copy.deepcopy(self.gptj)
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
@@ -110,3 +115,34 @@ def test_new_api(self, gpt_j_model):
         assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
         assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
+
+    def test_prepare_and_convert_api(self):
+        inp = torch.ones([1, 10], dtype=torch.long)
+        gpt_j_model = copy.deepcopy(self.gptj)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
+        )
+
+        out1 = gpt_j_model(inp)
+        quant_config = get_default_AutoRound_config()
+        logger.info(f"Test AutoRound with config {quant_config}")
+
+        run_fn = get_autoround_default_run_fn
+        run_args = (
+            tokenizer,
+            "NeelNanda/pile-10k",
+            20,
+            10,
+        )
+        fp32_model = gpt_j_model
+
+        # quantizer execute
+        model = prepare(model=fp32_model, quant_config=quant_config)
+        run_fn(model, *run_args)
+        q_model = convert(model)
+
+        out2 = q_model(inp)
+        assert torch.allclose(out1[0], out2[0], atol=1e-1)
+        assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
+        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
+        assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
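
Outside the test harness, the prepare, calibrate, convert flow that the new test covers looks like this; the model, tokenizer, and calibration arguments simply mirror the test above:

import transformers

from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
from neural_compressor.torch.quantization import convert, get_default_AutoRound_config, prepare

fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", torchscript=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)

quant_config = get_default_AutoRound_config()

# prepare() builds an AutoRoundQuantizer and caches it on the returned model.
model = prepare(model=fp32_model, quant_config=quant_config)

# Calibrate; the (dataset, 20, 10) arguments mirror the test above.
get_autoround_default_run_fn(model, tokenizer, "NeelNanda/pile-10k", 20, 10)

# convert() picks up the cached quantizer and produces the quantized model.
q_model = convert(model)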
