
Register pt2e static quantization #1761

Merged: 18 commits, May 9, 2024
3 changes: 3 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
75 changes: 16 additions & 59 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Note - The `W8A8StaticQuantizer` is aligned with the pytorch-labs/ao's unified quantization API.
# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
# Some code snippets are taken from the X86InductorQuantizer tutorial.
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html


from typing import Any, Dict, Optional, Tuple, Union
from typing import Any

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
@@ -28,71 +26,30 @@
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config


class W8A8StaticQuantizer:
class W8A8StaticQuantizer(Quantizer):

@staticmethod
def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
# TODO: add the logic to update the quantizer based on the quant_config
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
if not quant_config:
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
else:
quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
return quantizer

@staticmethod
def export_model(
model,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
exported_model = None
try:
with torch.no_grad():
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
# updated to use the official `torch.export` API when that is ready.
cur_version = get_torch_version()
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
logger.warning(
(
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
"If you want to use `dynamic_shapes` to export model, "
"please upgrade to 2.3.0 or later."
),
cur_version,
)
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
else: # pragma: no cover
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
model, args=example_inputs, dynamic_shapes=dynamic_shapes
)
except Exception as e:
logger.error(f"Failed to export the model: {e}")
return exported_model

def prepare(
self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
) -> GraphModule:
def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
"""Prepare the model for calibration.

There are two steps in this process:
1) export the eager model into model with Aten IR.
2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
"""
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
# Set the model to eval mode
model = model.eval()

# 1) Capture the FX Graph to be quantized
dynamic_shapes = kwargs.get("dynamic_shapes", None)
exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
logger.info("Exported the model to Aten IR successfully.")
if exported_model is None:
return

# 2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
quantizer = X86InductorQuantizer()
quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
prepared_model = prepare_pt2e(exported_model, quantizer)
quant_config = self.quant_config
assert model._exported, "The model should be exported before preparing it for calibration."
quantizer = self.update_quantizer_based_on_quant_config(quant_config)
prepared_model = prepare_pt2e(model, quantizer)
return prepared_model

def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
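With the export step moved out (see `neural_compressor/torch/export/_export.py` below), `prepare` now only builds the quantizer and inserts observers. A minimal end-to-end sketch of how the reworked class is driven; the toy model, the calibration loop, and the `quant_config=None` default are illustrative assumptions, not code from this PR:

```python
import torch

from neural_compressor.torch.algorithms.pt2e_quant import W8A8StaticQuantizer
from neural_compressor.torch.export import export_model_for_pt2e_quant

# Illustrative eager model and example inputs (not part of this PR).
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
example_inputs = (torch.randn(2, 8),)

# 1) Export to Aten IR; the helper marks the result with `_exported = True`,
#    which `prepare` now asserts on instead of exporting internally.
exported_model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)

# 2) Insert observers; with no quant_config the default X86Inductor config is used.
quantizer = W8A8StaticQuantizer(quant_config=None)
prepared_model = quantizer.prepare(exported_model, example_inputs=example_inputs)

# 3) Calibrate on representative data, then convert to the quantized model.
for _ in range(4):
    prepared_model(*example_inputs)
quantized_model = quantizer.convert(prepared_model)
```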
15 changes: 15 additions & 0 deletions neural_compressor/torch/export/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.torch.export._export import export_model_for_pt2e_quant, export
73 changes: 73 additions & 0 deletions neural_compressor/torch/export/_export.py
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch._export import capture_pre_autograd_graph
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported

__all__ = ["export", "export_model_for_pt2e_quant"]


def export_model_for_pt2e_quant(
model: torch.nn.Module,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
"""Export the eager model into model with Aten IR."""
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
# Set the model to eval mode
model = model.eval()
exported_model = None
try:
with torch.no_grad():
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
# updated to use the official `torch.export` API when that is ready.
cur_version = get_torch_version()
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
logger.warning(
(
"`dynamic_shapes` is not supported in the current version(%s) of PyTorch,"
"If you want to use `dynamic_shapes` to export model, "
"please upgrade to 2.3.0 or later."
),
cur_version,
)
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
else:
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
model, args=example_inputs, dynamic_shapes=dynamic_shapes
)
exported_model._exported = True
logger.info("Exported the model to Aten IR successfully.")
except Exception as e:
logger.error(f"Failed to export the model: {e}")

return exported_model


def export(
model: torch.nn.Module,
example_inputs: Tuple[Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Optional[GraphModule]:
if not is_ipex_imported():
return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
else:
# TODO, add `export` for ipex
pass
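
A hedged usage sketch of the new `export` helper; the toy module and the `torch.export.Dim`-style `dynamic_shapes` spec are assumptions following the upstream `torch.export` convention, not code from this PR:

```python
import torch

from neural_compressor.torch.export import export


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)


example_inputs = (torch.randn(2, 8),)

# Static-shape export; returns a GraphModule tagged with `_exported = True`,
# or None if the export fails (or if IPEX is imported, per the TODO above).
exported_model = export(ToyModel(), example_inputs=example_inputs)

# Dynamic batch dimension; on torch <= 2.2.2 the helper logs a warning and
# exports with static shapes instead.
batch = torch.export.Dim("batch")
exported_dyn = export(
    ToyModel(),
    example_inputs=example_inputs,
    dynamic_shapes={"x": {0: batch}},
)
```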
24 changes: 23 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -30,7 +30,8 @@
StaticQuantConfig,
TEQConfig,
)
from neural_compressor.torch.utils import Mode, logger, register_algo
from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT


###################### RTN Algo Entry ##################################
@@ -147,6 +148,8 @@ def static_quant_entry(
*args,
**kwargs,
) -> torch.nn.Module:
if not is_ipex_imported():
return pt2e_static_quant_entry(model, configs_mapping, mode, *args, **kwargs)
logger.info("Quantize model with the static quant algorithm.")
from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer

@@ -191,6 +194,25 @@ def static_quant_entry(
return model


###################### PT2E Static Quant Algo Entry ##################################
@register_algo(name=PT2E_STATIC_QUANT)
@torch.no_grad()
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
for _, quant_config in configs_mapping.items():
if quant_config.name == STATIC_QUANT:
w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
return model


###################### Smooth Quant Algo Entry ##################################
@register_algo(name=SMOOTH_QUANT)
@torch.no_grad()
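For context, a rough sketch of how the new entry is reached once it is registered under `PT2E_STATIC_QUANT`; the import paths, the direct call, the `Mode.QUANTIZE` value, and the single-entry `configs_mapping` (mirroring the new `StaticQuantConfig.to_config_mapping`) are assumptions for illustration rather than the public API:

```python
from collections import OrderedDict

import torch

from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import StaticQuantConfig
from neural_compressor.torch.quantization.algorithm_entry import pt2e_static_quant_entry
from neural_compressor.torch.utils import Mode

model = torch.nn.Linear(8, 4)
example_inputs = (torch.randn(2, 8),)
exported_model = export(model, example_inputs=example_inputs)

# Without IPEX imported, static_quant_entry forwards straight to this entry.
quant_config = StaticQuantConfig(w_granularity="per_tensor")
configs_mapping = OrderedDict({quant_config.name: quant_config})


def run_fn(prepared):
    # Calibration function: run representative inputs through the prepared model.
    prepared(*example_inputs)


quantized_model = pt2e_static_quant_entry(
    exported_model,
    configs_mapping,
    Mode.QUANTIZE,
    run_fn=run_fn,
    example_inputs=example_inputs,
)
```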
22 changes: 18 additions & 4 deletions neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
# pylint:disable=import-error

from collections import OrderedDict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch

@@ -40,7 +40,7 @@
STATIC_QUANT,
TEQ,
)
from neural_compressor.torch.utils import is_hpex_available, logger
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger
from neural_compressor.torch.utils.constants import (
PRIORITY_AUTOROUND,
PRIORITY_AWQ,
@@ -820,19 +820,31 @@ def __init__(
@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_static_config = StaticQuantConfig()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively

_, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
return model_info

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
if is_ipex_imported():
return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
if is_ipex_imported():
return super().to_config_mapping(config_list, model_info)
config_mapping = OrderedDict({self.name: self})
return config_mapping

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
return StaticQuantConfig(act_sym=[True, False], act_algo=["kl", "minmax"])
@@ -844,6 +856,8 @@ def get_default_static_config() -> StaticQuantConfig:
Returns:
the default static quant config.
"""
if not is_ipex_imported():
return StaticQuantConfig(w_granularity="per_tensor")
return StaticQuantConfig()


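The net effect on the config side: when IPEX is not imported, the static-quant config acts as a single global setting for the whole graph rather than a per-op mapping, and the default weight granularity becomes per-tensor. A small illustration; the top-level import path and the attribute access are assumptions that follow the code above:

```python
from neural_compressor.torch.quantization import get_default_static_config

# With intel_extension_for_pytorch not imported, the PT2E default is returned.
config = get_default_static_config()
print(config.w_granularity)  # "per_tensor"

# to_config_mapping collapses to one global entry keyed by the config name,
# instead of the per-op mapping produced for the IPEX path.
mapping = config.to_config_mapping()
print(list(mapping.keys()))  # [config.name]
```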
3 changes: 3 additions & 0 deletions neural_compressor/torch/utils/constants.py
@@ -49,3 +49,6 @@
PRIORITY_AWQ = 70
PRIORITY_TEQ = 60
PRIORITY_AUTOROUND = 50


PT2E_STATIC_QUANT = "pt2e_static_quant"
9 changes: 9 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch
from packaging.version import Version

@@ -65,6 +67,13 @@ def get_torch_version():
return version


def is_ipex_imported() -> bool:
for name, _ in sys.modules.items():
if name == "intel_extension_for_pytorch":
return True
return False


def get_device(device_name="auto"):
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

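The check is purely based on `sys.modules`, so import order matters: IPEX has to be imported before the quantization entry runs for the IPEX path to be selected. A brief illustration (assumes IPEX is installed):

```python
from neural_compressor.torch.utils import is_ipex_imported

print(is_ipex_imported())  # False: intel_extension_for_pytorch not imported yet

import intel_extension_for_pytorch  # noqa: F401  (requires IPEX to be installed)

print(is_ipex_imported())  # True: the module is now present in sys.modules
```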