
Commit 01513a3

nirda7 and ulivne authored
Support FP8 Quantization and Inference Run on Intel Gaudi (HPU) using INC (Intel Neural Compressor) (#12010)
Signed-off-by: Nir David <ndavid@habana.ai>
Signed-off-by: Uri Livne <ulivne@habana.ai>
Co-authored-by: Uri Livne <ulivne@habana.ai>
1 parent ac2bf41 commit 01513a3

File tree

11 files changed (+168, -25 lines)


docs/features/quantization/README.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ Contents:
 - [BitBLAS](bitblas.md)
 - [GGUF](gguf.md)
 - [GPTQModel](gptqmodel.md)
+- [INC](inc.md)
 - [INT4 W4A16](int4.md)
 - [INT8 W8A8](int8.md)
 - [FP8 W8A8](fp8.md)

docs/features/quantization/inc.md

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+---
+title: FP8 INC
+---
+[](){ #inc }
+
+vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
+Currently, quantization is validated only on Llama models.
+
+Intel Gaudi supports quantization of various modules and functions, including, but not limited to, `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to:
+[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
+
+!!! note
+    Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
+
+!!! note
+    `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
+    The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
+
+## Run Online Inference Using FP8
+
+Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
+
+```bash
+export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
+vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
+```
+
+!!! tip
+    If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
+
+!!! tip
+    When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the following environment variables:
+    `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
+    `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in milliseconds, e.g., 600000 equals 10 minutes.
+
+## Run Offline Inference Using FP8
+
+To run offline inference (after completing the model calibration process):
+
+* Set the `QUANT_CONFIG` environment variable to point to a JSON configuration file with QUANTIZE mode.
+* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
+* Call the `shutdown` method of the `model_executor` at the end of the run.
+
+```python
+from vllm import LLM
+llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
+...
+# Call llm.generate on the required prompts and sampling params.
+...
+llm.llm_engine.model_executor.shutdown()
+```
+
+## Device for Loading the Model Weights
+
+The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
+This reduces the device memory footprint of the model weights, as only the quantized weights are stored in device memory.
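For readers who want the offline path end to end, here is a minimal sketch that stitches the pieces of this doc together. It is hedged: it assumes a Gaudi (HPU) build of vLLM containing this commit, and the model name, `QUANT_CONFIG` path, and prompt are placeholders.

```python
import os

# Placeholder path: point INC at a quantization-mode JSON config prepared
# during calibration (see the vllm-hpu-extension calibration README).
os.environ["QUANT_CONFIG"] = "/path/to/quant/config/maxabs_quant.json"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    quantization="inc",        # quantize weights via INC while loading
    kv_cache_dtype="fp8_inc",  # keep the KV cache in FP8 on the HPU
)

outputs = llm.generate(
    ["What does FP8 quantization change about inference?"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
for out in outputs:
    print(out.outputs[0].text)

# As the doc notes, release the executor explicitly at the end of the run.
llm.llm_engine.model_executor.shutdown()
```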

docs/features/quantization/supported_hardware.md

Lines changed: 13 additions & 12 deletions
@@ -2,18 +2,19 @@

 The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:

-| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
-|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
-| AWQ || ✅︎ | ✅︎ | ✅︎ | ✅︎ || ✅︎ | ✅︎ |||
-| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ || ✅︎ | ✅︎ |||
-| Marlin (GPTQ/AWQ/FP8) ||| ✅︎ | ✅︎ | ✅︎ ||||||
-| INT8 (W8A8) || ✅︎ | ✅︎ | ✅︎ | ✅︎ ||| ✅︎ | ✅︎ | ✅︎ |
-| FP8 (W8A8) |||| ✅︎ | ✅︎ | ✅︎ ||| ✅︎ ||
-| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ ||||||
-| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ ||||||
-| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ ||||||
-| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ ||||||
-| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ |||||
+| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
+|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
+| AWQ || ✅︎ | ✅︎ | ✅︎ | ✅︎ || ✅︎ || ✅︎ |||
+| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ || ✅︎ || ✅︎ |||
+| Marlin (GPTQ/AWQ/FP8) ||| ✅︎ | ✅︎ | ✅︎ |||||||
+| INT8 (W8A8) || ✅︎ | ✅︎ | ✅︎ | ✅︎ |||| ✅︎ | ✅︎ | ✅︎ |
+| FP8 (W8A8) |||| ✅︎ | ✅︎ | ✅︎ |||| ✅︎ ||
+| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ |||||||
+| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ |||||||
+| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ |||||||
+| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ |||||||
+| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ ||||||
+| INC (W8A8) |||||||| ✅︎ ||||

 - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
 - ✅︎ indicates that the quantization method is supported on the specified hardware.

docs/getting_started/installation/intel_gaudi.md

Lines changed: 3 additions & 2 deletions
@@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
 hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
 apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
 pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
-pip list | grep neural # verify that neural_compressor is installed
+pip list | grep neural # verify that neural_compressor_pt is installed
 ```

 Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
@@ -120,12 +120,13 @@ docker run \
 - Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
   for accelerating low-batch latency and throughput
 - Attention with Linear Biases (ALiBi)
+- INC quantization

 ### Unsupported features

 - Beam search
 - LoRA adapters
-- Quantization
+- AWQ quantization
 - Prefill chunking (mixed-batch inferencing)

 ### Supported configurations
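Alongside the shell checks above, a quick Python probe can confirm that the INC dependency is present before attempting an FP8 run. This is a hedged sketch; the distribution name `neural_compressor_pt` is taken from the `pip list | grep neural` comment and may differ in your environment.

```python
from importlib import metadata

# Assumed distribution name, mirroring the verification comment above.
DIST = "neural_compressor_pt"

try:
    print(f"{DIST} {metadata.version(DIST)} is installed")
except metadata.PackageNotFoundError:
    print(f"{DIST} not found - INC-based FP8 quantization will not work")
```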

vllm/config.py

Lines changed: 8 additions & 5 deletions
@@ -963,7 +963,7 @@ def _verify_quantization(self) -> None:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
-            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
+            "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
         ]
         if self.quantization is not None:
             self.quantization = cast(me_quant.QuantizationMethods,
@@ -1563,7 +1563,7 @@ def get_and_verify_max_len(self, max_model_len: int):


 BlockSize = Literal[1, 8, 16, 32, 64, 128]
-CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
 PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]


@@ -1593,7 +1593,7 @@ class CacheConfig:
     cache_dtype: CacheDType = "auto"
     """Data type for kv cache storage. If "auto", will use model data type.
     CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
-    fp8 (=fp8_e4m3)."""
+    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
     is_attention_free: bool = False
     """Whether the model is attention-free. This is primarily set in
     `ModelConfig` and that value should be manually duplicated here."""
@@ -1691,7 +1691,7 @@ def _verify_cache_dtype(self) -> None:
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
                 "Meanwhile, it may cause accuracy drop without a proper "
-                "scaling factor")
+                "scaling factor.")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -1781,6 +1781,9 @@ class LoadConfig:
         default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
+    device: Optional[str] = None
+    """Device to which model weights will be loaded; defaults to
+    device_config.device."""
     ignore_patterns: Optional[Union[list[str], str]] = None
     """The list of patterns to ignore when loading the model. Default to
     "original/**/*" to avoid repeated loading of llama's checkpoints."""
@@ -1907,7 +1910,7 @@ class ParallelConfig:
     or equal to the number of GPUs available, "mp" will be used to
     keep processing on a single host. Otherwise, this will default
     to "ray" if Ray is installed and fail otherwise. Note that tpu
-    and hpu only support Ray for distributed inference."""
+    only supports Ray for distributed inference."""

     worker_cls: str = "auto"
     """The full name of the worker class to use. If "auto", the worker class

vllm/engine/arg_utils.py

Lines changed: 9 additions & 1 deletion
@@ -139,6 +139,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
     return type_hints


+def is_online_quantization(quantization: Any) -> bool:
+    return quantization in ["inc"]
+
+
 @functools.lru_cache(maxsize=30)
 def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
     cls_docs = get_attr_docs(cls)
@@ -960,6 +964,8 @@ def create_load_config(self) -> LoadConfig:
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
+            device="cpu"
+            if is_online_quantization(self.quantization) else None,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
             use_tqdm_on_load=self.use_tqdm_on_load,
@@ -1359,7 +1365,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             supported = False
             if current_platform.is_rocm() or (
                     current_platform.is_cuda()
-                    and current_platform.is_device_capability(100)):
+                    and current_platform.is_device_capability(100)) or (
+                        current_platform.device_name
+                        == "hpu"):  # handle hpu also for OOT platform
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
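The effect of the `create_load_config` change can be summarized in a standalone sketch: INC is treated as an "online" quantization method, so the weight-load device is forced to the CPU, while every other method keeps the default (resolved later from `device_config.device`). This is an illustrative re-expression of the diff above, not the shipped code.

```python
from typing import Any, Optional

def is_online_quantization(quantization: Any) -> bool:
    # Same rule as the helper added above: only INC quantizes during loading.
    return quantization in ["inc"]

def pick_load_device(quantization: Optional[str]) -> Optional[str]:
    # Sketch of the create_load_config() logic: load weights on the CPU for
    # INC so only the quantized copy lands in HPU memory; otherwise leave the
    # device unset and fall back to device_config.device.
    return "cpu" if is_online_quantization(quantization) else None

print(pick_load_device("inc"))  # cpu
print(pick_load_device("fp8"))  # None
```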

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -36,6 +36,7 @@
     "torchao",
     "auto-round",
     "rtn",
+    "inc",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .gptq_marlin import GPTQMarlinConfig
     from .gptq_marlin_24 import GPTQMarlin24Config
     from .hqq_marlin import HQQMarlinConfig
+    from .inc import INCConfig
     from .ipex_quant import IPEXConfig
     from .marlin import MarlinConfig
     from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
@@ -144,7 +146,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "moe_wna16": MoeWNA16Config,
         "torchao": TorchAOConfig,
         "auto-round": AutoRoundConfig,
-        "rtn": RTNConfig
+        "rtn": RTNConfig,
+        "inc": INCConfig,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
@@ -157,4 +160,4 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     "QuantizationMethods",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
-]
+]
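With the registration above in place, the usual registry lookup should resolve the new method. A hedged usage sketch; it assumes a vLLM build that contains this commit is importable.

```python
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

assert "inc" in QUANTIZATION_METHODS

inc_cls = get_quantization_config("inc")
print(inc_cls.get_name())                  # "inc"
print(inc_cls.get_supported_act_dtypes())  # [torch.bfloat16]
```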
vllm/model_executor/layers/quantization/inc.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Intel Gaudi supports quantization of various modules and functions,
+# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`.
+# During model loading,
+# INC will patch layers with quantization/dequantization operators.
+# Meanwhile, INC will convert the original weights to the target datatype
+# and load them to the target device.
+# Static scaling should be provided through QUANT_CONFIG:
+# `QUANT_CONFIG` is an environment variable
+# that points to the measurement or quantization JSON config file.
+# The measurement configuration file is used during the calibration procedure
+# to collect measurements for a given model.
+# The quantization configuration is used during inference.
+# For more information, please refer to:
+# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html
+
+from typing import Any, Optional
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+
+
+class INCConfig(QuantizationConfig):
+    """Config class for FP8 using Intel Neural Compressor."""
+
+    @classmethod
+    def get_name(cls) -> QuantizationMethods:
+        return "inc"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "INCConfig":
+        raise AssertionError
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+        return None
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise AssertionError
+
+    @staticmethod
+    def get_config_filenames() -> list[str]:
+        return []
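Because INC patches the modules itself during weight loading, `INCConfig.get_quant_method` simply hands back the unquantized compute methods. The standalone sketch below mocks the vLLM layer classes to show that dispatch; it is illustrative only, not the shipped code.

```python
# Mock stand-ins for the vllm.model_executor.layers.* classes used above.
class LinearBase: ...
class FusedMoE: ...
class UnquantizedLinearMethod: ...
class UnquantizedFusedMoEMethod: ...

def get_quant_method(layer):
    # Same shape as INCConfig.get_quant_method: return the unquantized method
    # and let INC insert quant/dequant operators while the model loads.
    if isinstance(layer, LinearBase):
        return UnquantizedLinearMethod()
    if isinstance(layer, FusedMoE):
        return UnquantizedFusedMoEMethod()
    return None

print(type(get_quant_method(LinearBase())).__name__)  # UnquantizedLinearMethod
print(get_quant_method(object()))                     # None
```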

vllm/model_executor/model_loader/base_loader.py

Lines changed: 9 additions & 1 deletion
@@ -6,9 +6,12 @@
 import torch.nn as nn

 from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader.utils import (
     initialize_model, process_weights_after_loading, set_default_torch_dtype)

+logger = init_logger(__name__)
+

 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -32,11 +35,16 @@ def load_model(self, vllm_config: VllmConfig,
                    model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
+        load_config = vllm_config.load_config
+        load_device = device_config.device if load_config.device is None else \
+            load_config.device
+        target_device = torch.device(load_device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(vllm_config=vllm_config,
                                          model_config=model_config)
+
+        logger.debug("Loading weights on %s ...", load_device)
         # Quantization does not happen in `load_weights` but after it
         self.load_weights(model, model_config)
         process_weights_after_loading(model, model_config, target_device)
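The device-selection change reduces to one fallback rule, shown here as a hedged standalone sketch: an explicit `load_config.device` (for example `"cpu"` when INC is active) wins over `device_config.device`.

```python
from typing import Optional

import torch

def resolve_load_device(device_config_device: str,
                        load_config_device: Optional[str]) -> torch.device:
    # Mirrors BaseModelLoader.load_model(): prefer the explicit load device,
    # otherwise fall back to the execution device.
    load_device = (device_config_device
                   if load_config_device is None else load_config_device)
    return torch.device(load_device)

print(resolve_load_device("hpu", None))   # hpu
print(resolve_load_device("hpu", "cpu"))  # cpu
```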

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def get_quant_config(model_config: ModelConfig,
     quant_cls = get_quantization_config(model_config.quantization)

     # GGUF doesn't have config file
-    if model_config.quantization == "gguf":
-        return quant_cls.from_config({})
+    if model_config.quantization in ("gguf", "inc"):
+        return quant_cls()

     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
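The `get_quant_config` shortcut now covers INC as well as GGUF: neither carries a Hugging Face `quantization_config`, so the config class is instantiated directly instead of being parsed with `from_config`. A hedged standalone sketch of that branch, with a dummy class standing in for a real vLLM config such as `INCConfig`:

```python
class DummyQuantConfig:
    """Stand-in for a vLLM QuantizationConfig subclass such as INCConfig."""

    @classmethod
    def from_config(cls, config: dict) -> "DummyQuantConfig":
        return cls()

def build_quant_config(quantization: str, quant_cls, hf_quant_config=None):
    # Mirrors the updated branch: GGUF and INC have no config file to parse.
    if quantization in ("gguf", "inc"):
        return quant_cls()
    return quant_cls.from_config(hf_quant_config or {})

print(type(build_quant_config("inc", DummyQuantConfig)).__name__)
```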
