Commit cad4375

Author: tangbinhan (committed)

add loadtimequantization modelloader

1 parent 68b4755 · commit cad4375

Some content is hidden: large commits have part of their diff collapsed by default, so not all changed files are shown below.

41 files changed: +1040 −455 lines

docs/get_started/ernie-4.5.md

Lines changed: 15 additions & 0 deletions
@@ -30,6 +30,21 @@ python -m fastdeploy.entrypoints.openai.api_server \
 --max-num-seqs 32
 ```

+To speed up model loading, set the environment variable **export FD_USE_FASTSAFETENSOR=1** and use the **--load_format "load_time_quantization"** option.
+
+```shell
+export FD_USE_FASTSAFETENSOR=1
+python -m fastdeploy.entrypoints.openai.api_server \
+--model baidu/ERNIE-4.5-300B-A47B-Paddle \
+--port 8180 --engine-worker-queue-port 8181 \
+--cache-queue-port 8182 --metrics-port 8182 \
+--tensor-parallel-size 8 \
+--quantization wint4 \
+--max-model-len 32768 \
+--max-num-seqs 32 \
+--load_format "load_time_quantization"
+```
+
 ## Request the Service
 After starting the service, the following output indicates successful initialization:

docs/quantization/online_quantization.md

Lines changed: 21 additions & 2 deletions
@@ -24,7 +24,7 @@ python -m fastdeploy.entrypoints.openai.api_server \

 - By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md).
 - By setting `--quantization` to `wint8` or `wint4`, online INT8/INT4 quantization can be selected.
-- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G * 8 cards, while WINT4 requires 80GB * 4 cards.
+- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G *8 cards, while WINT4 requires 80GB* 4 cards.
 - For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md).

 ## 2. Block-wise FP8
@@ -51,4 +51,23 @@ python -m fastdeploy.entrypoints.openai.api_server \
 - By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md).
 - By setting `--quantization` to `block_wise_fp8`, online Block-wise FP8 quantization can be selected.
 - Deploying ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 requires at least 80G * 8 cards.
-- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md)
+- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md)
+
+# LoadTimeQuantization
+To speed up loading with FastSafeTensor and load large bfloat16 models onto the GPU, we shift quantization to the weight-loading stage and perform it dynamically. Supported quantization formats include INT4, INT8, and FP8.
+
+## 1. Run the load-time quantization model loader
+To speed up model loading, set the environment variable **export FD_USE_FASTSAFETENSOR=1** and use the **--load_format "load_time_quantization"** option.
+
+```shell
+export FD_USE_FASTSAFETENSOR=1
+python -m fastdeploy.entrypoints.openai.api_server \
+--model baidu/ERNIE-4.5-300B-A47B-Paddle \
+--port 8180 --engine-worker-queue-port 8181 \
+--cache-queue-port 8182 --metrics-port 8182 \
+--tensor-parallel-size 8 \
+--quantization wint8 \
+--max-model-len 32768 \
+--max-num-seqs 32 \
+--load_format "load_time_quantization"
+```
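The doc change above describes load-time quantization only briefly. The following minimal, self-contained NumPy sketch (hypothetical helper names, not FastDeploy's actual loader) illustrates the idea: each tensor is quantized to INT8 with per-output-channel scales as it is read, so the full bfloat16 checkpoint never has to be materialized at once.

```python
import numpy as np

def quantize_int8(weight: np.ndarray):
    """Symmetric per-output-channel INT8 quantization (toy version)."""
    scale = np.abs(weight).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)           # avoid division by zero
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def load_with_quantization(weight_iter):
    """Quantize each tensor as it streams in, instead of after a full load."""
    quantized = {}
    for name, w in weight_iter:                        # one tensor at a time
        q, scale = quantize_int8(w.astype(np.float32))
        quantized[name] = (q, scale)                   # only INT8 data + scales are kept
    return quantized

# Toy checkpoint: two "layers" of random weights standing in for safetensors shards.
checkpoint = {"layer0.weight": np.random.randn(8, 16),
              "layer1.weight": np.random.randn(8, 16)}
state = load_with_quantization(checkpoint.items())
print({k: (q.dtype, s.shape) for k, (q, s) in state.items()})
```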

docs/zh/get_started/ernie-4.5.md

Lines changed: 15 additions & 0 deletions
@@ -31,6 +31,21 @@ python -m fastdeploy.entrypoints.openai.api_server \
 --max-num-seqs 32
 ```

+You can set the environment variable **export FD_USE_FASTSAFETENSOR=1** and add the **--load_format "load_time_quantization"** option to speed up weight loading.
+
+```shell
+export FD_USE_FASTSAFETENSOR=1
+python -m fastdeploy.entrypoints.openai.api_server \
+--model baidu/ERNIE-4.5-300B-A47B-Paddle \
+--port 8180 --engine-worker-queue-port 8181 \
+--cache-queue-port 8182 --metrics-port 8182 \
+--tensor-parallel-size 8 \
+--quantization wint4 \
+--max-model-len 32768 \
+--max-num-seqs 32 \
+--load_format "load_time_quantization"
+```
+
 ## Request the Service
 After the launch command runs, the following terminal output indicates that the service has started successfully.

docs/zh/quantization/online_quantization.md

Lines changed: 20 additions & 3 deletions
@@ -23,8 +23,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
 ```

 - Specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle` downloads the model automatically from AIStudio. FastDeploy requires Paddle-format models; see the [Supported Model List](../supported_models.md) for details.
-- Set `--quantization` to `wint8` or `wint4` to select online INT8/INT4 quantization.
-- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G * 8 cards, while WINT4 requires 80GB * 4 cards.
+- Set `--quantization` to `wint8` or `wint4` to select online INT8/INT4 quantization.
+- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G *8 cards, while WINT4 requires 80GB* 4 cards.
 - For more deployment tutorials, see [get_started](../get_started/ernie-4.5.md).

 ## 2. Block-wise FP8
@@ -49,9 +49,26 @@ python -m fastdeploy.entrypoints.openai.api_server \
 ```

 - Specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle` downloads the model automatically from AIStudio. FastDeploy requires Paddle-format models; see the [Supported Model List](../supported_models.md) for details.
-- Set `--quantization` to `block_wise_fp8` to select online Block-wise FP8 quantization.
+- Set `--quantization` to `block_wise_fp8` to select online Block-wise FP8 quantization.
 - Deploying ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 requires at least 80G * 8 cards.
 - For more deployment tutorials, see [get_started](../get_started/ernie-4.5.md)

+# LoadTimeQuantization
+To improve weight-loading performance with fastsafetensor and fit the 300B model onto the GPU, we provide a new model loader (loadtimequantization) that quantizes weights dynamically while they are being loaded.
+This model loader supports INT4, INT8, and FP8 dynamic quantization.

+## 1. Use the loadtimequantization model loader
+Enable fastsafetensor by setting the environment variable **export FD_USE_FASTSAFETENSOR=1**, and enable quantization during weight loading by passing **--load_format "load_time_quantization"**.

+```shell
+export FD_USE_FASTSAFETENSOR=1
+python -m fastdeploy.entrypoints.openai.api_server \
+--model baidu/ERNIE-4.5-300B-A47B-Paddle \
+--port 8180 --engine-worker-queue-port 8181 \
+--cache-queue-port 8182 --metrics-port 8182 \
+--tensor-parallel-size 8 \
+--quantization wint8 \
+--max-model-len 32768 \
+--max-num-seqs 32 \
+--load_format "load_time_quantization"
+```

fastdeploy/config.py

Lines changed: 9 additions & 1 deletion
@@ -18,7 +18,7 @@

 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Literal, Optional
+from typing import Any, Dict, Literal, Optional, Union

 from paddleformers.transformers.configuration_utils import PretrainedConfig

@@ -55,6 +55,7 @@ class ModelConfig(PretrainedConfig):
     frequency_score = 0.0
     presence_score = 0.0
     min_length = 1
+    weight_infos_dict: Dict[str, Any] = {}

     def __init__(
         self,
@@ -343,6 +344,12 @@ def __init__(self,
         self.graph_opt_level = 1


+class LoadFormat(str, Enum):
+    """LoadFormat"""
+    DEFAULT = "default"
+    LoadTimeQuant = "load_time_quantization"
+
+
 @dataclass
 class LoadConfig:
     """
@@ -357,6 +364,7 @@ class LoadConfig:
     - 'meta': provide RL traing worker, no_weights_load
     - None: No dynamic loading
     """
+    load_format: Union[str, LoadFormat] = LoadFormat.DEFAULT.value
     use_fastsafetensor: bool = False
     dynamic_load_weight: bool = False
     load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
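The diff above adds the `LoadFormat` enum and the `LoadConfig.load_format` field but does not show where they are consumed. A hedged sketch of the dispatch pattern such a field typically drives (the `get_model_loader` function and the loader names here are illustrative assumptions, not FastDeploy code):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Union

class LoadFormat(str, Enum):
    """Mirrors the enum added in fastdeploy/config.py."""
    DEFAULT = "default"
    LoadTimeQuant = "load_time_quantization"

@dataclass
class LoadConfig:
    load_format: Union[str, LoadFormat] = LoadFormat.DEFAULT.value

def get_model_loader(load_config: LoadConfig) -> str:
    """Hypothetical dispatcher: pick a loader from load_format."""
    fmt = LoadFormat(load_config.load_format)   # raises ValueError on unknown values
    if fmt is LoadFormat.LoadTimeQuant:
        return "LoadTimeQuantModelLoader"       # placeholder for the quantizing loader
    return "DefaultModelLoader"                 # placeholder for the default loader

print(get_model_loader(LoadConfig("load_time_quantization")))  # LoadTimeQuantModelLoader
print(get_model_loader(LoadConfig()))                          # DefaultModelLoader
```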

fastdeploy/engine/args_utils.py

Lines changed: 16 additions & 0 deletions
@@ -293,6 +293,14 @@ class EngineArgs:
     max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
     """

+    load_format: str = "default"
+    """The format of the model weights to load.
+    Options include:
+    - "default": default loader.
+    - "load_time_quantization": quantization applied during model loading,
+      such as INT8, INT4, or FP8 formats.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -413,6 +421,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help=
             "Disabled any whitespaces when using guided decoding backend XGrammar."
         )
+        # Load group
+        load_group = parser.add_argument_group("Load Configuration")
+        load_group.add_argument("--load_format",
+                                type=str,
+                                default=EngineArgs.load_format,
+                                help="The format of the model weights to load: "
+                                     "default / load_time_quantization.")

         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -784,4 +799,5 @@ def create_engine_config(self) -> Config:
             max_capture_batch_size=self.max_capture_batch_size,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
+            load_format=self.load_format,
         )
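To make the new flag's behaviour concrete, here is a standalone snippet mirroring the added `--load_format` argument with plain `argparse` (FastDeploy's `FlexibleArgumentParser` is not used here, and restricting `choices` to the two known values is an editorial suggestion, not part of the commit):

```python
import argparse

parser = argparse.ArgumentParser()
load_group = parser.add_argument_group("Load Configuration")
load_group.add_argument("--load_format",
                        type=str,
                        default="default",
                        choices=["default", "load_time_quantization"],
                        help="The format of the model weights to load: "
                             "default / load_time_quantization.")

args = parser.parse_args(["--load_format", "load_time_quantization"])
print(args.load_format)   # load_time_quantization
```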

fastdeploy/engine/config.py

Lines changed: 5 additions & 1 deletion
@@ -494,6 +494,7 @@ class Config:
         splitwise_role (str): Splitwise role.
         innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
             Temporary configuration, will be removed in the future.
+        load_format (str): The format of the model weights to load. Default is "default".
     """

     def __init__(
@@ -526,6 +527,7 @@ def __init__(
         max_capture_batch_size: int = 64,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
+        load_format: str = "default",
     ):
         """
         Initialize the Config class.
@@ -554,6 +556,7 @@ def __init__(
             guided_decoding_backend(str): Guided decoding backend. Default is None.
             disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
                 Default is False.
+            load_format (str): The format of the model weights to load. Default is "default".
         """
         self.model_config = model_config
         self.cache_config = cache_config
@@ -585,7 +588,8 @@ def __init__(
         self.is_master = True
         self._str_to_list("innode_prefill_ports", int)
         self._str_to_list("pod_ips", str)
-
+        self.load_format = load_format
+
         if self.pod_ips is None:
             self.nnode = 1
         else:

fastdeploy/engine/engine.py

Lines changed: 4 additions & 3 deletions
@@ -998,8 +998,8 @@ def _start_worker_service(self):
         py_script = os.path.join(current_dir_path, worker_path)

         ori_vocab_size = (
-            len(self.data_processor.tokenizer.sp_model)
-            if hasattr(self.data_processor.tokenizer, 'sp_model')
+            len(self.data_processor.tokenizer.sp_model)
+            if hasattr(self.data_processor.tokenizer, 'sp_model')
             else len(self.data_processor.tokenizer.vocab)
         )

@@ -1032,7 +1032,8 @@ def _start_worker_service(self):
             f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
             f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-            f" --load_strategy {self.cfg.model_config.load_strategy}")
+            f" --load_strategy {self.cfg.model_config.load_strategy}"
+            f" --load_format {self.cfg.load_format}")

         worker_append_flag = {
             "enable_expert_parallel":

fastdeploy/model_executor/layers/activation.py

Lines changed: 11 additions & 3 deletions
@@ -78,9 +78,17 @@ def __init__(
         self.shift = shift
         self.smooth = smooth
         self.quant_scale = quant_scale
-        self.quant_round_type = fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
-        self.quant_max_bound = fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
-        self.quant_min_bound = fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
+        if fd_config.quant_config:
+            self.quant_round_type = fd_config.quant_config.get_quant_method(
+                self).quant_round_type
+            self.quant_max_bound = fd_config.quant_config.get_quant_method(
+                self).quant_max_bound
+            self.quant_min_bound = fd_config.quant_config.get_quant_method(
+                self).quant_min_bound
+        else:
+            self.quant_round_type = 0
+            self.quant_max_bound = 0
+            self.quant_min_bound = 0

         self._dtype = self._helper.get_default_dtype()
         if self._dtype == "bfloat16":
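The change above stops reading the rounding type and bounds from global attributes on `quant_config` and instead reads them from the per-layer quant method object. A toy illustration of that difference (the `QuantConfig`, `QuantMethod`, and `ToyActivation` classes below are stand-ins, not FastDeploy classes):

```python
class QuantMethod:
    """Per-layer quantization method carrying rounding/bound settings."""
    def __init__(self, round_type, max_bound, min_bound):
        self.quant_round_type = round_type
        self.quant_max_bound = max_bound
        self.quant_min_bound = min_bound

class QuantConfig:
    """Exposes a (possibly layer-specific) method instead of global attributes."""
    def get_quant_method(self, layer):
        return QuantMethod(round_type=0, max_bound=127, min_bound=-127)

class ToyActivation:
    def __init__(self, quant_config=None):
        if quant_config:
            method = quant_config.get_quant_method(self)
            self.quant_round_type = method.quant_round_type
            self.quant_max_bound = method.quant_max_bound
            self.quant_min_bound = method.quant_min_bound
        else:                       # no quantization configured
            self.quant_round_type = 0
            self.quant_max_bound = 0
            self.quant_min_bound = 0

print(ToyActivation(QuantConfig()).quant_max_bound)   # 127
print(ToyActivation().quant_max_bound)                # 0
```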

fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py

Lines changed: 15 additions & 4 deletions
@@ -53,13 +53,24 @@ def create_weights(self, layer: nn.Layer) -> None:
             is_bias=False,
         )

-    def process_loaded_weights(self, layer: nn.Layer,
-                               weight: paddle.Tensor) -> None:
+    def process_quantized_weights(self, layer, state_dict) -> None:
+        """process_quantized_weights"""
+        # (tangbinhan:todo) quant_utils support xpu
+        layer.linear_weight.set_value(state_dict.pop(layer.weight_key))
+        layer.linear_weight_scale.set_value(state_dict.pop(layer.weight_scale))
+
+    def apply_weight_quantization(self, weight):
+        """apply_weight_quantization"""
+        quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
+            weight, self.quant_config.algo, -1, -1)
+        return quanted_weight_tensor, weight_scale_tensor
+
+    def process_unquantized_weights(self, layer, weight) -> None:
         """
         loaded_weights using xpu special quantization
         """
-        quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
-            weight, self.quant_config.algo, -1, -1)
+        quanted_weight_tensor, weight_scale_tensor = self.apply_weight_quantization(
+            weight)
         layer.linear_weight.set_value(
             paddle.transpose(quanted_weight_tensor, [1, 0]))
         layer.linear_weight_scale.set_value(weight_scale_tensor)
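The refactor above splits weight handling into one path for checkpoints that already contain quantized tensors and one path that quantizes raw weights on the fly. A NumPy-only sketch of that split (toy layer and quantizer, assuming made-up shapes; not the XPU implementation):

```python
import numpy as np

class ToyLayer:
    weight_key = "w"
    weight_scale = "w_scale"
    def __init__(self):
        self.linear_weight = None
        self.linear_weight_scale = None

def apply_weight_quantization(weight):
    """Stand-in for weight_quantize_xpu: symmetric INT8 with one scale per column."""
    scale = np.abs(weight).max(axis=0) / 127.0
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale

def process_quantized_weights(layer, state_dict):
    """Checkpoint already holds INT8 weights and scales: just install them."""
    layer.linear_weight = state_dict.pop(layer.weight_key)
    layer.linear_weight_scale = state_dict.pop(layer.weight_scale)

def process_unquantized_weights(layer, weight):
    """Raw float weights: quantize during loading, then install."""
    q, scale = apply_weight_quantization(weight)
    layer.linear_weight = q.T          # transposed, as in the XPU code path
    layer.linear_weight_scale = scale

layer = ToyLayer()
process_unquantized_weights(layer, np.random.randn(16, 8))
print(layer.linear_weight.dtype, layer.linear_weight.shape)   # int8 (8, 16)
```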
