
Commit 63f96fb

modify op_type for set_local in 3.x API
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent 5f3f388 commit 63f96fb

File tree: 3 files changed, +14 -3 lines changed
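The change replaces string op_type keys ("add") with the operator objects themselves (torch.add, torch.nn.Linear) when calling set_local. A minimal sketch of the two fallback modes, based on the calls in the diffs below and assuming the 3.x neural_compressor.torch API; the "fc1" module name is taken from the test file and is purely illustrative:

    import torch
    from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config

    quant_config = get_default_static_config()

    # Fallback by op_type: pass the operator object itself,
    # e.g. torch.add or torch.nn.Linear, rather than a bare string like "add".
    quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

    # Fallback by op_name: a string still selects a single named module.
    quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))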

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -341,13 +341,13 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
         quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)
 
         if re.search("gpt", user_model.config.model_type):
-            quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+            quant_config.set_local(torch.add, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
     else:
         from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig
 
         quant_config = get_default_static_config()
         if re.search("gpt", user_model.config.model_type):
-            quant_config.set_local("add", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+            quant_config.set_local(torch.add, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
 
         from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
         from tqdm import tqdm

test/3x/torch/quantization/test_smooth_quant.py

Lines changed: 10 additions & 0 deletions
@@ -55,6 +55,16 @@ def test_smooth_quant_auto(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
 
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_smooth_quant_fallback(self):
+        fp32_model = copy.deepcopy(model)
+        quant_config = get_default_sq_config()
+        example_inputs = torch.randn([1, 3])
+        # fallback by op_type
+        quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo, alpha, folding, scale_sharing",

test/3x/torch/quantization/test_static_quant.py

Lines changed: 2 additions & 1 deletion
@@ -61,11 +61,12 @@ def test_static_quant_fallback(self):
         quant_config = get_default_static_config()
         example_inputs = self.input
         # fallback by op_type
-        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
 
         # fallback by op_name
+        quant_config = get_default_static_config()
         quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
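A note on the get_default_static_config() line added before the op_name check: set_local records its override on the config object in place, so reusing the earlier config would presumably keep the torch.nn.Linear fallback active as well; rebuilding the config keeps the two fallback checks independent. A sketch of the pattern, assuming the same fixtures as the test:

    # Fresh config so only the op_name override is active in this check.
    quant_config = get_default_static_config()
    quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))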
