diff --git a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
index 28f108cb636..b3ed5d11fe3 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/core.py b/neural_compressor/torch/algorithms/pt2e_quant/core.py
index 4ceda55890d..c0983ee7aad 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/core.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Note - The `W8A8StaticQuantizer` is aligned with with the pytorch-labs/ao's unified quantization API.
-# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
 # Some code snippets are taken from the X86InductorQuantizer tutorial.
 # https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html
 
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any
 
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
@@ -28,71 +26,30 @@
 from torch.fx.graph_module import GraphModule
 
 from neural_compressor.common.utils import logger
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.algorithms.base_algorithm import Quantizer
+from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
 
 
-class W8A8StaticQuantizer:
+class W8A8StaticQuantizer(Quantizer):
 
     @staticmethod
-    def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
-        # TODO: add the logic to update the quantizer based on the quant_config
-        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+    def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
+        if not quant_config:
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+        else:
+            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
         return quantizer
 
-    @staticmethod
-    def export_model(
-        model,
-        example_inputs: Tuple[Any],
-        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-    ) -> Optional[GraphModule]:
-        exported_model = None
-        try:
-            with torch.no_grad():
-                # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
-                # updated to use the official `torch.export` API when that is ready.
-                cur_version = get_torch_version()
-                if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
-                    logger.warning(
-                        (
-                            "`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
-                            "If you want to use `dynamic_shapes` to export model, "
-                            "please upgrade to 2.3.0 or later."
-                        ),
-                        cur_version,
-                    )
-                    exported_model = capture_pre_autograd_graph(model, args=example_inputs)
-                else:  # pragma: no cover
-                    exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
-                        model, args=example_inputs, dynamic_shapes=dynamic_shapes
-                    )
-        except Exception as e:
-            logger.error(f"Failed to export the model: {e}")
-        return exported_model
-
-    def prepare(
-        self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
-    ) -> GraphModule:
+    def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
         """Prepare the model for calibration.
 
-        There are two steps in this process:
-            1) export the eager model into model with Aten IR.
-            2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
+        Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
         """
-        assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
-        # Set the model to eval mode
-        model = model.eval()
-
-        # 1) Capture the FX Graph to be quantized
-        dynamic_shapes = kwargs.get("dynamic_shapes", None)
-        exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
-        logger.info("Exported the model to Aten IR successfully.")
-        if exported_model is None:
-            return
-
-        # 2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
-        quantizer = X86InductorQuantizer()
-        quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
-        prepared_model = prepare_pt2e(exported_model, quantizer)
+        quant_config = self.quant_config
+        assert model._exported, "The model should be exported before preparing it for calibration."
+        quantizer = self.update_quantizer_based_on_quant_config(quant_config)
+        prepared_model = prepare_pt2e(model, quantizer)
         return prepared_model
 
     def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
diff --git a/neural_compressor/torch/export/__init__.py b/neural_compressor/torch/export/__init__.py
new file mode 100644
index 00000000000..6d7af54f5c5
--- /dev/null
+++ b/neural_compressor/torch/export/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.torch.export._export import export_model_for_pt2e_quant, export
diff --git a/neural_compressor/torch/export/_export.py b/neural_compressor/torch/export/_export.py
new file mode 100644
index 00000000000..579e816894f
--- /dev/null
+++ b/neural_compressor/torch/export/_export.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch._export import capture_pre_autograd_graph
+from torch.fx.graph_module import GraphModule
+
+from neural_compressor.common.utils import logger
+from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
+
+__all__ = ["export", "export_model_for_pt2e_quant"]
+
+
+def export_model_for_pt2e_quant(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any],
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+) -> Optional[GraphModule]:
+    """Export the eager model into model with Aten IR."""
+    assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
+    # Set the model to eval mode
+    model = model.eval()
+    exported_model = None
+    try:
+        with torch.no_grad():
+            # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
+            # updated to use the official `torch.export` API when that is ready.
+            cur_version = get_torch_version()
+            if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
+                logger.warning(
+                    (
+                        "`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
+                        "If you want to use `dynamic_shapes` to export model, "
+                        "please upgrade to 2.3.0 or later."
+                    ),
+                    cur_version,
+                )
+                exported_model = capture_pre_autograd_graph(model, args=example_inputs)
+            else:
+                exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
+                    model, args=example_inputs, dynamic_shapes=dynamic_shapes
+                )
+            exported_model._exported = True
+            logger.info("Exported the model to Aten IR successfully.")
+    except Exception as e:
+        logger.error(f"Failed to export the model: {e}")
+
+    return exported_model
+
+
+def export(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any],
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+) -> Optional[GraphModule]:
+    if not is_ipex_imported():
+        return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
+    else:
+        # TODO, add `export` for ipex
+        pass
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 41d2593c224..7e344d91ad5 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -30,7 +30,8 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import Mode, logger, register_algo
+from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
+from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
 
 
 ###################### RTN Algo Entry ##################################
@@ -147,6 +148,8 @@ def static_quant_entry(
     *args,
     **kwargs,
 ) -> torch.nn.Module:
+    if not is_ipex_imported():
+        return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
     logger.info("Quantize model with the static quant algorithm.")
     from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
 
@@ -191,6 +194,25 @@ def static_quant_entry(
     return model
 
 
+###################### PT2E Static Quant Algo Entry ##################################
+@register_algo(name=PT2E_STATIC_QUANT)
+@torch.no_grad()
+def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
+    logger.info("Quantize model with the PT2E static quant algorithm.")
+    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+
+    run_fn = kwargs.get("run_fn", None)
+    example_inputs = kwargs.get("example_inputs", None)
+    inplace = kwargs.get("inplace", True)
+    for _, quant_config in configs_mapping.items():
+        if quant_config.name == STATIC_QUANT:
+            w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
+            model = w8a8_quantizer.execute(
+                model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
+            )
+    return model
+
+
 ###################### Smooth Quant Algo Entry ##################################
 @register_algo(name=SMOOTH_QUANT)
 @torch.no_grad()
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index a052c923c81..9c1505c06a7 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
 
 # pylint:disable=import-error
 from collections import OrderedDict
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 
@@ -40,7 +40,7 @@
     STATIC_QUANT,
     TEQ,
 )
-from neural_compressor.torch.utils import is_hpex_available, logger
+from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger
 from neural_compressor.torch.utils.constants import (
     PRIORITY_AUTOROUND,
     PRIORITY_AWQ,
@@ -820,19 +820,31 @@ def __init__(
     @classmethod
     def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs = []
-        # TODO(Yi)
         linear_static_config = StaticQuantConfig()
         operators = [torch.nn.Linear]
         supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
         cls.supported_configs = supported_configs
 
     @staticmethod
-    def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
+    def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively
 
         _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
+    @staticmethod
+    def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+        if is_ipex_imported():
+            return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
+        if is_ipex_imported():
+            return super().to_config_mapping(config_list, model_info)
+        config_mapping = OrderedDict({self.name: self})
+        return config_mapping
+
     @classmethod
     def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
         return StaticQuantConfig(act_sym=[True, False], act_algo=["kl", "minmax"])
@@ -844,6 +856,8 @@ def get_default_static_config() -> StaticQuantConfig:
     Returns:
         the default static quant config.
     """
+    if not is_ipex_imported():
+        return StaticQuantConfig(w_granularity="per_tensor")
     return StaticQuantConfig()
 
 
diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py
index 8264130864e..54a68163ded 100644
--- a/neural_compressor/torch/utils/constants.py
+++ b/neural_compressor/torch/utils/constants.py
@@ -49,3 +49,6 @@
 PRIORITY_AWQ = 70
 PRIORITY_TEQ = 60
 PRIORITY_AUTOROUND = 50
+
+
+PT2E_STATIC_QUANT = "pt2e_static_quant"
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index 2d26e34af79..3a98d963e09 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
+
 import torch
 from packaging.version import Version
 
@@ -65,6 +67,13 @@ def get_torch_version():
     return version
 
 
+def is_ipex_imported() -> bool:
+    for name, _ in sys.modules.items():
+        if name == "intel_extension_for_pytorch":
+            return True
+    return False
+
+
 def get_device(device_name="auto"):
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index 135c4025c10..176be26181d 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -17,6 +17,10 @@
 from typing import Callable, Dict, List, Tuple, Union
 
 import torch
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
+from torch.ao.quantization.quantizer import QuantizationSpec
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 from typing_extensions import TypeAlias
 
 from neural_compressor.common import logger
@@ -131,3 +135,56 @@ class Mode(Enum):
     PREPARE = "prepare"
     CONVERT = "convert"
     QUANTIZE = "quantize"
+
+
+def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
+    dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
+    qscheme_mapping = {
+        "per_channel": {True: torch.per_channel_symmetric, False: torch.per_tensor_affine},
+        "per_tensor": {True: torch.per_tensor_symmetric, False: torch.per_tensor_affine},
+    }
+    observer_mapping = {
+        "minmax": MinMaxObserver,
+        "kl": HistogramObserver,
+    }
+    # algo
+    observer_or_fake_quant_ctr = observer_mapping[algo]
+    # qscheme
+    qscheme = qscheme_mapping[granularity][sym]
+    quantization_spec = QuantizationSpec(
+        dtype=dtype_mapping[dtype], observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, qscheme=qscheme
+    )
+    return quantization_spec
+
+
+def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
+    default_quant_config = xiq.get_default_x86_inductor_quantization_config()
+    input_act_quant_spec = create_quant_spec_from_config(
+        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo
+    )
+    weight_quant_spec = create_quant_spec_from_config(
+        inc_config.w_dtype, inc_config.w_sym, inc_config.w_granularity, inc_config.w_algo
+    )
+    quant_config = QuantizationConfig(
+        input_activation=input_act_quant_spec,
+        output_activation=default_quant_config.output_activation,
+        weight=weight_quant_spec,
+        bias=default_quant_config.bias,
+        is_qat=False,
+    )
+    return quant_config
+
+
+def create_xiq_quantizer_from_pt2e_config(config) -> X86InductorQuantizer:
+    quantizer = xiq.X86InductorQuantizer()
+    # set global
+    global_config = _map_inc_config_to_torch_quant_config(config)
+    quantizer.set_global(global_config)
+    # set local
+    for module_or_func_name, local_config in config.local_config.items():
+        local_quant_config = _map_inc_config_to_torch_quant_config(local_config)
+        if isinstance(module_or_func_name, torch.nn.Module):
+            quantizer.set_module_type_qconfig(module_or_func_name, local_quant_config)
+        else:
+            quantizer.set_function_type_qconfig(module_or_func_name, local_quant_config)
+    return quantizer
diff --git a/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py b/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
index 2f48a5a2454..103d0e3350e 100644
--- a/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
+++ b/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
@@ -6,6 +6,7 @@
 
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+from neural_compressor.torch.export import export_model_for_pt2e_quant
 from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
 
 
@@ -45,15 +46,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         model = SimpleModel()
         example_inputs = (torch.randn(10, 10),)
-        return model, example_inputs
+        exported_model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
+        return exported_model, example_inputs
 
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     def test_quantizer_on_simple_model(self):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
-        quant_config = None
         w8a8_static_quantizer = W8A8StaticQuantizer()
         # prepare
-        prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
+        prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
         # calibrate
         for i in range(2):
             prepare_model(*example_inputs)
@@ -77,10 +78,12 @@ def test_quantizer_on_llm(self):
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
         example_inputs = (input_ids,)
+        model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
+
         quant_config = None
         w8a8_static_quantizer = W8A8StaticQuantizer()
         # prepare
-        prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
+        prepare_model = w8a8_static_quantizer.prepare(model)
         # calibrate
         for i in range(2):
             prepare_model(*example_inputs)
@@ -97,9 +100,8 @@ def test_quantizer_on_llm(self):
     @patch("neural_compressor.torch.algorithms.pt2e_quant.core.logger.error")
     def test_export_model_failed(self, mock_error):
         model, example_inputs = self.get_toy_model()
-        w8a8_static_quantizer = W8A8StaticQuantizer()
         # export model
-        exported_model = w8a8_static_quantizer.export_model(model, example_inputs=example_inputs)
+        exported_model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
         assert exported_model is None
         call_args_list = mock_error.call_args_list
         assert any(["Failed to export the model" in msg for msg in [info[0][0] for info in call_args_list]])
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
new file mode 100644
index 00000000000..23e56d7220b
--- /dev/null
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -0,0 +1,134 @@
+import os
+import unittest
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from neural_compressor.common.utils import logger
+from neural_compressor.torch.export import export
+from neural_compressor.torch.quantization import (
+    StaticQuantConfig,
+    convert,
+    get_default_static_config,
+    prepare,
+    quantize,
+)
+from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
+
+
+class TestPT2EQuantization:
+
+    @staticmethod
+    def get_toy_model():
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                x = a / (torch.abs(a) + 1)
+                if b.sum() < 0:
+                    b = b * -1
+                return x * b
+
+        inp1 = torch.randn(10)
+        inp2 = torch.randn(10)
+        example_inputs = (inp1, inp2)
+        bar = Bar()
+        return bar, example_inputs
+
+    @staticmethod
+    def build_simple_torch_model_and_example_inputs():
+        class SimpleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(10, 20)
+                self.fc2 = torch.nn.Linear(20, 10)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.fc1(x)
+                x = torch.nn.functional.relu(x)
+                x = self.fc2(x)
+                return x
+
+        model = SimpleModel()
+        example_inputs = (torch.randn(10, 10),)
+        exported_model = export(model, example_inputs=example_inputs)
+        return exported_model, example_inputs
+
+    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
+    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    def test_quantize_simple_model(self):
+        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
+        quant_config = None
+
+        def calib_fn(model):
+            for i in range(2):
+                model(*example_inputs)
+
+        quant_config = get_default_static_config()
+        q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(q_model)
+        out = opt_model(*example_inputs)
+        logger.warning("out shape is %s", out.shape)
+        assert out is not None
+
+    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
+    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    def test_prepare_and_convert_on_simple_model(self):
+        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
+        quant_config = None
+
+        def calib_fn(model):
+            for i in range(2):
+                model(*example_inputs)
+
+        quant_config = get_default_static_config()
+
+        prepared_model = prepare(model, quant_config=quant_config)
+        calib_fn(prepared_model)
+        q_model = convert(prepared_model)
+        assert q_model is not None, "Quantization failed!"
+
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(q_model)
+        out = opt_model(*example_inputs)
+        logger.warning("out shape is %s", out.shape)
+        assert out is not None
+
+    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
+    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    def test_prepare_and_convert_on_llm(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # set TOKENIZERS_PARALLELISM to false
+
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        model_name = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
+        example_inputs = (input_ids,)
+        model = export(model, example_inputs=example_inputs)
+
+        quant_config = get_default_static_config()
+        # prepare
+        prepare_model = prepare(model, quant_config)
+        # calibrate
+        for i in range(2):
+            prepare_model(*example_inputs)
+        # convert
+        converted_model = convert(prepare_model)
+        # inference
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(converted_model)
+        out = opt_model(*example_inputs)
+        assert out.logits is not None
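Reviewer note (not part of the diff): the sketch below mirrors the flow exercised by test_prepare_and_convert_on_simple_model in the new test/3x/torch/quantization/test_pt2e_quant.py and shows the end-to-end PT2E static quantization path this PR enables. It assumes torch >= 2.3.0 and that intel_extension_for_pytorch is not imported; the Sequential toy model and the `inductor_config` alias are illustrative placeholders, not names from the PR.

import torch

from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import convert, get_default_static_config, prepare

# Illustrative toy model and example inputs (placeholders, not from the PR).
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 10))
example_inputs = (torch.randn(10, 10),)

# 1) Export the eager model to Aten IR; returns None (and logs an error) if export fails.
exported_model = export(model, example_inputs=example_inputs)

# 2) Insert observers; with IPEX absent, StaticQuantConfig routes to the PT2E entry.
quant_config = get_default_static_config()
prepared_model = prepare(exported_model, quant_config=quant_config)

# 3) Calibrate by running representative inputs through the prepared model.
for _ in range(2):
    prepared_model(*example_inputs)

# 4) Convert to the quantized model and, as in the tests, compile it with Inductor freezing.
q_model = convert(prepared_model)

from torch._inductor import config as inductor_config

inductor_config.freezing = True
opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)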