Add prepare/convert ut for autoround #1768

Merged
merged 6 commits on May 6, 2024
@@ -89,7 +89,7 @@ def __init__(
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
have different choices.
"""

super().__init__(weight_config)
self.tokenizer = None
self.weight_config = weight_config
self.enable_full_range = enable_full_range
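The `super().__init__(weight_config)` call above hooks AutoRoundQuantizer into the common 3.x Quantizer base class that `autoround_quantize_entry` drives through `execute(mode=...)`. The sketch below is a hypothetical, simplified stand-in for that base class (the real one lives in neural_compressor and its exact API may differ); it only illustrates the dispatch the entry relies on.

```python
# Hypothetical, simplified stand-in for the 3.x Quantizer base class that
# AutoRoundQuantizer now initializes via super().__init__(weight_config).
# Names and mode values here are illustrative, not the real implementation.
class QuantizerBase:
    def __init__(self, quant_config=None):
        # Subclasses hand their per-layer weight_config to the base class.
        self.quant_config = quant_config

    def prepare(self, model, *args, **kwargs):
        raise NotImplementedError

    def convert(self, model, *args, **kwargs):
        raise NotImplementedError

    def quantize(self, model, *args, **kwargs):
        # One-shot path: prepare then convert in a single call.
        model = self.prepare(model, *args, **kwargs)
        return self.convert(model, *args, **kwargs)

    def execute(self, model, mode, *args, **kwargs):
        # The algorithm entry drives this dispatch with mode=prepare/convert/quantize.
        if mode == "prepare":
            return self.prepare(model, *args, **kwargs)
        if mode == "convert":
            return self.convert(model, *args, **kwargs)
        return self.quantize(model, *args, **kwargs)
```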
50 changes: 28 additions & 22 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -374,29 +374,35 @@ def autoround_quantize_entry(
     scale_dtype = quant_config.scale_dtype

     kwargs.pop("example_inputs")
-    quantizer = AutoRoundQuantizer(
-        weight_config=weight_config,
-        enable_full_range=enable_full_range,
-        batch_size=batch_size,
-        lr_scheduler=lr_scheduler,
-        use_quant_input=use_quant_input,
-        enable_minmax_tuning=enable_minmax_tuning,
-        lr=lr,
-        minmax_lr=minmax_lr,
-        low_gpu_mem_usage=low_gpu_mem_usage,
-        iters=iters,
-        seqlen=seqlen,
-        n_samples=n_samples,
-        sampler=sampler,
-        seed=seed,
-        n_blocks=n_blocks,
-        gradient_accumulate_steps=gradient_accumulate_steps,
-        not_use_best_mse=not_use_best_mse,
-        dynamic_max_gap=dynamic_max_gap,
-        scale_dtype=scale_dtype,
-    )
+    if getattr(model, "quantizer", False):
+        quantizer = model.quantizer
+    else:
+        quantizer = AutoRoundQuantizer(
+            weight_config=weight_config,
+            enable_full_range=enable_full_range,
+            batch_size=batch_size,
+            lr_scheduler=lr_scheduler,
+            use_quant_input=use_quant_input,
+            enable_minmax_tuning=enable_minmax_tuning,
+            lr=lr,
+            minmax_lr=minmax_lr,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            iters=iters,
+            seqlen=seqlen,
+            n_samples=n_samples,
+            sampler=sampler,
+            seed=seed,
+            n_blocks=n_blocks,
+            gradient_accumulate_steps=gradient_accumulate_steps,
+            not_use_best_mse=not_use_best_mse,
+            dynamic_max_gap=dynamic_max_gap,
+            scale_dtype=scale_dtype,
+        )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        del model.quantizer
+    else:
+        model.quantizer = quantizer
     logger.info("AutoRound quantization done.")
     return model

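With this branch, the same entry point serves both halves of the new flow: `prepare()` builds the AutoRoundQuantizer and caches it on the returned model, while `convert()` picks up the cached instance and drops the attribute once quantization is done. A user-side sketch of that round trip, mirroring the new `test_new_api` unit test further down (model name, dataset, and calibration arguments are taken from that test), looks roughly like this:

```python
import torch
import transformers

from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
from neural_compressor.torch.quantization import convert, get_default_AutoRound_config, prepare

# Tiny model used by the unit test; any supported causal LM would follow the same flow.
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", torchscript=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)

quant_config = get_default_AutoRound_config()

# prepare() reaches autoround_quantize_entry in prepare mode; the quantizer is
# created and stored on the model (the `model.quantizer = quantizer` branch).
model = prepare(model=fp32_model, quant_config=quant_config)

# Calibration with the default run function; positional arguments follow the
# new test (tokenizer, dataset name, then what appear to be n_samples and seqlen).
get_autoround_default_run_fn(model, tokenizer, "NeelNanda/pile-10k", 20, 10)

# convert() reuses the cached quantizer (the `quantizer = model.quantizer`
# branch) and deletes the attribute after quantization completes.
q_model = convert(model)
output = q_model(torch.ones([1, 10], dtype=torch.long))
```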
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/config.py
@@ -685,7 +685,7 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        scale_dtype: str = "fp16",
+        scale_dtype: str = "fp32",
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
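This flips the default quantization-scale dtype from fp16 to fp32, which is what the updated tests assert (`torch.float32` in `autoround_config`). Callers that still want fp16 scales can request them explicitly; a small illustrative sketch:

```python
from neural_compressor.torch.quantization import AutoRoundConfig

# New default: fp32 quantization scales; all other fields keep their defaults.
cfg_default = AutoRoundConfig()

# Kernels that expect fp16 scales can still ask for them explicitly.
cfg_fp16 = AutoRoundConfig(scale_dtype="fp16")
```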
58 changes: 47 additions & 11 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -1,9 +1,17 @@
+import copy
+
 import pytest
 import torch
 import transformers

 from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
-from neural_compressor.torch.quantization import AutoRoundConfig, quantize
+from neural_compressor.torch.quantization import (
+    AutoRoundConfig,
+    convert,
+    get_default_AutoRound_config,
+    prepare,
+    quantize,
+)
 from neural_compressor.torch.utils import logger

 try:
@@ -14,8 +22,7 @@
     auto_round_installed = False


-@pytest.fixture(scope="module")
-def gpt_j():
+def get_gpt_j():
     tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
         "hf-internal-testing/tiny-random-GPTJForCausalLM",
         torchscript=True,
Expand All @@ -25,17 +32,15 @@ def gpt_j():

@pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
class TestAutoRound:
@staticmethod
@pytest.fixture(scope="class", autouse=True)
def gpt_j_model(gpt_j):
yield gpt_j
def setup_class(self):
self.gptj = get_gpt_j()

def setup_method(self, method):
logger.info(f"Running TestAutoRound test: {method.__name__}")

def test_autoround(self, gpt_j_model):
def test_autoround(self):
inp = torch.ones([1, 10], dtype=torch.long)

gpt_j_model = copy.deepcopy(self.gptj)
tokenizer = transformers.AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)
@@ -73,9 +78,9 @@ def test_autoround(self, gpt_j_model):
         assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]

-    def test_new_api(self, gpt_j_model):
+    def test_quantizer(self):
         inp = torch.ones([1, 10], dtype=torch.long)
-
+        gpt_j_model = copy.deepcopy(self.gptj)
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
@@ -110,3 +115,34 @@ def test_new_api(self, gpt_j_model):
         assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
         assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
+
+    def test_new_api(self):
+        inp = torch.ones([1, 10], dtype=torch.long)
+        gpt_j_model = copy.deepcopy(self.gptj)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
+        )
+
+        out1 = gpt_j_model(inp)
+        quant_config = get_default_AutoRound_config()
+        logger.info(f"Test AutoRound with config {quant_config}")
+
+        run_fn = get_autoround_default_run_fn
+        run_args = (
+            tokenizer,
+            "NeelNanda/pile-10k",
+            20,
+            10,
+        )
+        fp32_model = gpt_j_model
+
+        # quantizer execute
+        model = prepare(model=fp32_model, quant_config=quant_config)
+        run_fn(model, *run_args)
+        q_model = convert(model)
+
+        out2 = q_model(inp)
+        assert torch.allclose(out1[0], out2[0], atol=1e-1)
+        assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
+        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
+        assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
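To run just the new prepare/convert test locally (it is skipped automatically when the auto_round package is not installed), something along these lines should work with a standard pytest setup:

```python
# Equivalent to the command line:
# pytest test/3x/torch/quantization/weight_only/test_autoround.py::TestAutoRound::test_new_api -s
import pytest

pytest.main(
    ["-s", "test/3x/torch/quantization/weight_only/test_autoround.py::TestAutoRound::test_new_api"]
)
```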