diff --git a/docs/get_started/ernie-4.5.md b/docs/get_started/ernie-4.5.md index fe36640a3b..40cc7cf34c 100644 --- a/docs/get_started/ernie-4.5.md +++ b/docs/get_started/ernie-4.5.md @@ -30,6 +30,21 @@ python -m fastdeploy.entrypoints.openai.api_server \ --max-num-seqs 32 ``` +To speed up model loading, set the environment variable **export FD_USE_FASTSAFETENSOR=1** and use the **--load_format "load_time_quantization"** option. + +```shell +export FD_USE_FASTSAFETENSOR=1 +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-300B-A47B-Paddle \ + --port 8180 --engine-worker-queue-port 8181 \ + --cache-queue-port 8182 --metrics-port 8182 \ + --tensor-parallel-size 8 \ + --quantization wint4 \ + --max-model-len 32768 \ + --max-num-seqs 32 \ + --load_format "load_time_quantization" +``` + ## Request the Service After starting the service, the following output indicates successful initialization: diff --git a/docs/quantization/online_quantization.md b/docs/quantization/online_quantization.md index 3e3f24df90..a2522dbcb6 100644 --- a/docs/quantization/online_quantization.md +++ b/docs/quantization/online_quantization.md @@ -24,7 +24,7 @@ python -m fastdeploy.entrypoints.openai.api_server \ - By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md). - By setting `--quantization` to `wint8` or `wint4`, online INT8/INT4 quantization can be selected. -- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G * 8 cards, while WINT4 requires 80GB * 4 cards. +- Deploying ERNIE-4.5-300B-A47B-Paddle WINT8 requires at least 80G *8 cards, while WINT4 requires 80GB* 4 cards. - For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md). ## 2. Block-wise FP8 @@ -51,4 +51,23 @@ python -m fastdeploy.entrypoints.openai.api_server \ - By specifying `--model baidu/ERNIE-4.5-300B-A47B-Paddle`, the model can be automatically downloaded from AIStudio. FastDeploy depends on Paddle format models. For more information, please refer to [Supported Model List](../supported_models.md). - By setting `--quantization` to `block_wise_fp8`, online Block-wise FP8 quantization can be selected. - Deploying ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 requires at least 80G * 8 cards. -- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md) +- For more deployment tutorials, please refer to [get_started](../get_started/ernie-4.5.md) + +# LoadTimeQuantization +To speed up loading with FastSafeTensor and load large bfloat16 models onto the GPU, we shifted quantization to the weight loading stage and performed it dynamically. This supports quantization formats such as INT4, INT8, and FP8. + +## 1. Run loadtimequant modelloader +To speed up model loading, set the environment variable **export FD_USE_FASTSAFETENSOR=1** and use the **--load_format "load_time_quantization"** option. 
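+The launch command below combines both settings; the ports and tensor-parallel size shown are illustrative and can be adjusted for your deployment.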
+ +``` +export FD_USE_FASTSAFETENSOR=1 +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-300B-A47B-Paddle \ + --port 8180 --engine-worker-queue-port 8181 \ + --cache-queue-port 8182 --metrics-port 8182 \ + --tensor-parallel-size 8 \ + --quantization wint8 \ + --max-model-len 32768 \ + --max-num-seqs 32\ + --load_format "load_time_quantization" +``` diff --git a/docs/zh/get_started/ernie-4.5.md b/docs/zh/get_started/ernie-4.5.md index 4c8bc6ea01..5d7a1934d6 100644 --- a/docs/zh/get_started/ernie-4.5.md +++ b/docs/zh/get_started/ernie-4.5.md @@ -31,6 +31,21 @@ python -m fastdeploy.entrypoints.openai.api_server \ --max-num-seqs 32 ``` +可以通过设置环境变量 **export FD_USE_FASTSAFETENSOR=1** 并添加参数 **--load_format "load_time_quantization"**,提升权重load速度, + +```shell +export FD_USE_FASTSAFETENSOR=1 +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-300B-A47B-Paddle \ + --port 8180 --engine-worker-queue-port 8181 \ + --cache-queue-port 8182 --metrics-port 8182 \ + --tensor-parallel-size 8 \ + --quantization wint4 \ + --max-model-len 32768 \ + --max-num-seqs 32 \ + --load_format "load_time_quantization" +``` + ## 用户发起服务请求 执行启动服务指令后,当终端打印如下信息,说明服务已经启动成功。 diff --git a/docs/zh/quantization/online_quantization.md b/docs/zh/quantization/online_quantization.md index f487f8ac87..fa8948e788 100644 --- a/docs/zh/quantization/online_quantization.md +++ b/docs/zh/quantization/online_quantization.md @@ -23,8 +23,8 @@ python -m fastdeploy.entrypoints.openai.api_server \ ``` - 通过指定 `--model baidu/ERNIE-4.5-300B-A47B-Paddle` 可自动从AIStudio下载模型。FastDeploy依赖Paddle格式的模型,更多说明参考[支持模型列表](../supported_models.md)。 -- 通过设置 `--quantization` 为 `wint8` 或 `wint4` 选择在线 INT8/INT4 量化。 -- 部署 ERNIE-4.5-300B-A47B-Paddle WINT8 最少需要 80G * 8卡, WINT4 则需要 80GB * 4卡。 +- 通过设置 `--quantization` 为 `wint8` 或 `wint4` 选择在线 INT8/INT4 量化。 +- 部署 ERNIE-4.5-300B-A47B-Paddle WINT8 最少需要 80G *8卡, WINT4 则需要 80GB* 4卡。 - 更多部署教程请参考[get_started](../get_started/ernie-4.5.md). ## 2. Block-wise FP8 @@ -49,9 +49,26 @@ python -m fastdeploy.entrypoints.openai.api_server \ ``` - 通过指定 `--model baidu/ERNIE-4.5-300B-A47B-Paddle` 可自动从AIStudio下载模型。FastDeploy依赖Paddle格式的模型,更多说明参考[支持模型列表](../supported_models.md)。 -- 通过设置 `--quantization` 为 `block_wise_fp8` 选择在线 Block-wise FP8 量化。 +- 通过设置 `--quantization` 为 `block_wise_fp8` 选择在线 Block-wise FP8 量化。 - 部署 ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 最少需要 80G * 8卡。 - 更多部署教程请参考[get_started](../get_started/ernie-4.5.md) +# LoadTimeQuantization +为了使用fastsafeTensor提升load权重性能,并将300B的模型load进gpu,我们提供了一种新的modelloder(loadtimequantization),可以在load权重的同时进行动态量化, +该modelloder支持INT4、INT8、FP8动态量化。 +## 1. 
使用 loadtimequantization modelloader +你可以通过增加环境变量**export FD_USE_FASTSAFETENSOR=1** 开启fastsafetensor,并通过传参数**--load_format "load_time_quantization"** 开启加载权重时量化 +``` +export FD_USE_FASTSAFETENSOR=1 +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-300B-A47B-Paddle \ + --port 8180 --engine-worker-queue-port 8181 \ + --cache-queue-port 8182 --metrics-port 8182 \ + --tensor-parallel-size 8 \ + --quantization wint8 \ + --max-model-len 32768 \ + --max-num-seqs 32\ + --load_format "load_time_quantization" +``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 446e59298d..4d55930e1d 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Literal, Optional +from typing import Any, Dict, Literal, Optional, Union from paddleformers.transformers.configuration_utils import PretrainedConfig @@ -55,6 +55,7 @@ class ModelConfig(PretrainedConfig): frequency_score = 0.0 presence_score = 0.0 min_length = 1 + weight_infos_dict: Dict[str, Any] = {} def __init__( self, @@ -343,6 +344,12 @@ def __init__(self, self.graph_opt_level = 1 +class LoadFormat(str, Enum): + """LoadFormat""" + DEFAULT = "default" + LoadTimeQuant = "load_time_quantization" + + @dataclass class LoadConfig: """ @@ -357,6 +364,7 @@ class LoadConfig: - 'meta': provide RL traing worker, no_weights_load - None: No dynamic loading """ + load_format: Union[str, LoadFormat] = LoadFormat.DEFAULT.value use_fastsafetensor: bool = False dynamic_load_weight: bool = False load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2611214cf2..5eb3e77501 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -293,6 +293,14 @@ class EngineArgs: max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64]. """ + load_format: str = "default" + """The format of the model weights to load. + Options include: + - "default": default loader. + -"load_time_quantization": Quantization applied during model loading, \ + such as INT8, INT4, or FP8 formats. + """ + def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. @@ -413,6 +421,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Disabled any whitespaces when using guided decoding backend XGrammar." ) + # Load group + load_group = parser.add_argument_group("Load Configuration") + load_group.add_argument("--load_format", + type=str, + default=EngineArgs.load_format, + help="The format of the model weights to load.\ + default/load_time_quantization.") # Parallel processing parameters group parallel_group = parser.add_argument_group("Parallel Configuration") @@ -784,4 +799,5 @@ def create_engine_config(self) -> Config: max_capture_batch_size=self.max_capture_batch_size, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + load_format=self.load_format, ) diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index d92f7c2a90..c2726670e6 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -494,6 +494,7 @@ class Config: splitwise_role (str): Splitwise role. innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Temporary configuration, will be removed in the future. 
+ load_format(str):The format of the model weights to load. .Default is default """ def __init__( @@ -526,6 +527,7 @@ def __init__( max_capture_batch_size: int = 64, guided_decoding_backend: Optional[str] = None, disable_any_whitespace: bool = False, + load_format: str = "default", ): """ Initialize the Config class. @@ -554,6 +556,7 @@ def __init__( guided_decoding_backend(str): Guided decoding backend. Default is None. disable_any_whitespace(bool): Disable any whitespace when using guided decoding. Default is False. + load_format(str):The format of the model weights to load. .Default is default """ self.model_config = model_config self.cache_config = cache_config @@ -585,7 +588,8 @@ def __init__( self.is_master = True self._str_to_list("innode_prefill_ports", int) self._str_to_list("pod_ips", str) - + self.load_format = load_format + if self.pod_ips is None: self.nnode = 1 else: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 5fca12f0b9..789c51eab6 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -998,8 +998,8 @@ def _start_worker_service(self): py_script = os.path.join(current_dir_path, worker_path) ori_vocab_size = ( - len(self.data_processor.tokenizer.sp_model) - if hasattr(self.data_processor.tokenizer, 'sp_model') + len(self.data_processor.tokenizer.sp_model) + if hasattr(self.data_processor.tokenizer, 'sp_model') else len(self.data_processor.tokenizer.vocab) ) @@ -1032,7 +1032,8 @@ def _start_worker_service(self): f" --speculative_model_quantization {self.cfg.speculative_config.quantization}" f" --max_capture_batch_size {self.cfg.max_capture_batch_size}" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" - f" --load_strategy {self.cfg.model_config.load_strategy}") + f" --load_strategy {self.cfg.model_config.load_strategy}" + f" --load_format {self.cfg.load_format}") worker_append_flag = { "enable_expert_parallel": diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index aa8ff7f2c1..ad5a83d61c 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -78,9 +78,17 @@ def __init__( self.shift = shift self.smooth = smooth self.quant_scale = quant_scale - self.quant_round_type = fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound = fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound = fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + if fd_config.quant_config: + self.quant_round_type = fd_config.quant_config.get_quant_method( + self).quant_round_type + self.quant_max_bound = fd_config.quant_config.get_quant_method( + self).quant_max_bound + self.quant_min_bound = fd_config.quant_config.get_quant_method( + self).quant_min_bound + else: + self.quant_round_type = 0 + self.quant_max_bound = 0 + self.quant_min_bound = 0 self._dtype = self._helper.get_default_dtype() if self._dtype == "bfloat16": diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py index db264b07a2..e0c3431436 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py @@ -53,13 +53,24 @@ def create_weights(self, layer: nn.Layer) -> None: is_bias=False, ) - def process_loaded_weights(self, layer: nn.Layer, - weight: 
paddle.Tensor) -> None: + def process_quantized_weights(self, layer, state_dict) -> None: + """process_quantized_weights""" + # (tangbinhan:todo) quant_utils support xpu + layer.linear_weight.set_value(state_dict.pop(layer.weight_key)) + layer.linear_weight_scale.set_value(state_dict.pop(layer.weight_scale)) + + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" + quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu( + weight, self.quant_config.algo, -1, -1) + return quanted_weight_tensor, weight_scale_tensor + + def process_unquantized_weights(self, layer, weight) -> None: """ loaded_weights using xpu special quantization """ - quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu( - weight, self.quant_config.algo, -1, -1) + quanted_weight_tensor, weight_scale_tensor = self.apply_weight_quantization( + weight) layer.linear_weight.set_value( paddle.transpose(quanted_weight_tensor, [1, 0])) layer.linear_weight_scale.set_value(weight_scale_tensor) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index b8dc49e1b0..5852d5a239 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -114,16 +114,16 @@ def init_weight(self): self.linear_shift = None self.linear_smooth = None - def load_prequant_weight(self, state_dict: dict): + def load_quantized_weight(self, state_dict): """ Load the prequantized weight from the state dictionary. Args: state_dict (dict): A dictionary containing the prequantized weights and scales. """ - self.quant_method.process_prequanted_weights(self, state_dict) + self.quant_method.process_quantized_weights(self, state_dict) - def load_weight(self, state_dict: dict): + def load_unquantized_weight(self, state_dict): """ Load the weight from the state dictionary. @@ -133,7 +133,7 @@ def load_weight(self, state_dict: dict): weight_tensor = get_tensor(state_dict.pop(self.weight_key)) if self.fd_config.quant_config: - self.quant_method.process_loaded_weights(self, weight_tensor) + self.quant_method.process_unquantized_weights(self, weight_tensor) else: self.linear_weight.set_value(weight_tensor) @@ -148,10 +148,9 @@ def load_state_dict(self, state_dict: dict): self.state_dict = state_dict assert self.weight_key is not None, 'weight_key should not be None.' if self.fd_config.model_config.is_quantized: - self.load_prequant_weight(state_dict) + self.load_quantized_weight(state_dict) else: - self.load_weight(state_dict) - + self.load_unquantized_weight(state_dict) # bias if self.with_bias: bias_tensor = paddle.to_tensor( @@ -366,25 +365,31 @@ def load_state_dict(self, state_dict: dict): """ # weight assert self.weight_key is not None, 'weight_key should not be None.' 
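+        # For pre-quantized checkpoints, materialize the fused weight and its scale in place so that
+        # super().load_state_dict() can pick them up; otherwise concatenate the gate/up projections below.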
- if self.weight_key in state_dict.keys(): - weight_tensor = get_tensor(state_dict.pop(self.weight_key)) + if self.fd_config.model_config.is_quantized: + state_dict[self.weight_key] = get_tensor( + state_dict.pop(self.weight_key)) + state_dict[self.weight_scale_key] = get_tensor( + state_dict.pop(self.weight_scale_key)) else: - gate_weight_key = self.weight_key.replace("up_gate_proj", - "gate_proj") - up_weight_key = self.weight_key.replace("up_gate_proj", "up_proj") - gate_tensor = get_tensor(state_dict.pop(gate_weight_key)) - up_tensor = get_tensor(state_dict.pop(up_weight_key)) - weight_tensor = paddle.concat([gate_tensor, up_tensor], axis=-1) + if self.weight_key in state_dict.keys(): + weight_tensor = get_tensor(state_dict.pop(self.weight_key)) + else: + gate_weight_key = self.weight_key.replace("up_gate_proj", + "gate_proj") + up_weight_key = self.weight_key.replace("up_gate_proj", "up_proj") + gate_tensor = get_tensor(state_dict.pop(gate_weight_key)) + up_tensor = get_tensor(state_dict.pop(up_weight_key)) + weight_tensor = paddle.concat([gate_tensor, up_tensor], axis=-1) - if self.with_bias: - gate_bias_key = self.bias_key.replace("up_gate_proj", - "gate_proj") - bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype( - paddle.get_default_dtype()) + if self.with_bias: + gate_bias_key = self.bias_key.replace("up_gate_proj", + "gate_proj") + bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype( + paddle.get_default_dtype()) - state_dict[self.bias_key] = bias_tensor + state_dict[self.bias_key] = bias_tensor - state_dict[self.weight_key] = weight_tensor + state_dict[self.weight_key] = weight_tensor super().load_state_dict(state_dict) @@ -421,7 +426,16 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True): with_bias=with_bias, add_bias=add_bias) - def load_weight(self, state_dict: dict): + def load_quantized_weight(self, state_dict): + """ + Load the prequantized weight from the state dictionary. + + Args: + state_dict (dict): A dictionary containing the prequantized weights and scales. + """ + self.quant_method.process_quantized_weights(self, state_dict) + + def load_unquantized_weight(self, state_dict): """ Load the weight from the state dictionary. @@ -447,7 +461,7 @@ def load_weight(self, state_dict: dict): weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0]) if self.fd_config.quant_config: - self.quant_method.process_loaded_weights(self, weight_tensor) + self.quant_method.process_unquantized_weights(self, weight_tensor) else: self.linear_weight.set_value(weight_tensor) @@ -460,12 +474,10 @@ def load_state_dict(self, state_dict: dict): """ # weight assert self.weight_key is not None, 'weight_key should not be None.' 
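+        # Route to the quantized or unquantized weight-loading path based on whether the checkpoint is pre-quantized.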
- # qkv fused in disk - if self.fd_config.model_config.is_quantized: - self.load_prequant_weight(state_dict) + self.load_quantized_weight(state_dict) else: - self.load_weight(state_dict) + self.load_unquantized_weight(state_dict) # bias if self.with_bias: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 3da7b783e4..3874579b4d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -57,9 +57,9 @@ def init_ep(self, layer: nn.Layer) -> None: layer.top_k, layer.hidden_size, layer.num_experts, layer.ep_size, layer.ep_rank) - def process_loaded_weights(self, layer, weights) -> None: + def process_unquantized_weights(self, layer, weights) -> None: """ - process_loaded_weights + process_unquantized_weights """ pass diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 3c00ddfe44..77897ac189 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -300,6 +300,30 @@ def __init__(self, quant_config): self.moe_quant_type = "w4a8" self.pack_num = 2 + def reorder(self, weight_tensor): + """reorder""" + for i in range(len(weight_tensor)): + quant_weight, _ = self.apply_weight_quantization(weight_tensor[i]) + weight_tensor[i] = quant_weight + + def process_quantized_weights(self, layer: nn.Layer, state_dict) -> None: + """process_quantized_weights""" + ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + weight_name = self.added_weight_attrs[idx] + self.reorder(weight_tensor) + weight_list = weight_tensor + quanted_weight = paddle.stack(weight_list, axis=0) + create_and_set_parameter(layer, weight_name, quanted_weight) + self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict) + + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" + quant_weight, scale = weight_quantize(weight, + algo=self.moe_quant_type, + arch=80) + return quant_weight, scale + def create_weights(self, layer: nn.Layer, state_dict): """ Paddle cutlass create weight process. @@ -310,9 +334,8 @@ def create_weights(self, layer: nn.Layer, state_dict): weight_name = self.added_weight_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): - quant_weight, scale = weight_quantize(weight_tensor[i], - algo=self.moe_quant_type, - arch=80) + quant_weight, scale = self.apply_weight_quantization( + weight_tensor[i]) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) create_and_set_parameter(layer, weight_name, quanted_weight) @@ -408,7 +431,7 @@ def __init__(self, quant_config): self.moe_quant_type = self.quant_config.algo self.pack_num = 1 - def process_prequanted_weights(self, layer: nn.Layer, state_dict): + def process_quantized_weights(self, layer: nn.Layer, state_dict): """ Paddle cutlass process prequanted weights. 
""" @@ -437,19 +460,30 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict): state_dict.pop( ffn2_expert_weight_scale_key.format(expert_idx)))) - ffn1_weight = paddle.stack(ffn1_weights, axis=0) - ffn2_weight = paddle.stack(ffn2_weights, axis=0) - ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0) - ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0) - name_tensor_map = { - "moe_ffn1_weight": ffn1_weight, - "moe_ffn2_weight": ffn2_weight, + "moe_ffn1_weight": ffn1_weights, + "moe_ffn2_weight": ffn2_weights, "moe_ffn1_weight_scale": ffn1_weight_scale, "moe_ffn2_weight_scale": ffn2_weight_scale } - for name, tensor in name_tensor_map.items(): - create_and_set_parameter(layer, name, tensor) + + for name, tensor_list in name_tensor_map.items(): + setattr( + layer, name, + layer.create_parameter( + shape=[len(tensor_list)] + list(tensor_list[0].shape), + dtype=tensor_list[0].dtype, + default_initializer=paddle.nn.initializer.Constant(0), + )) + total_len = len(tensor_list) + for idx in range(total_len): + t = tensor_list.pop(0) + getattr(layer, name)[idx, ...].set_value(t) + + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" + quant_weight, scale = weight_quantize(weight, algo=self.moe_quant_type) + return quant_weight, scale def create_weights(self, layer: nn.Layer, state_dict): """ @@ -465,8 +499,8 @@ def create_weights(self, layer: nn.Layer, state_dict): weight_list = [] weight_scale_list = [] for i in range(layer.num_local_experts): - quant_weight, scale = weight_quantize(weight_tensor[i], - algo=self.moe_quant_type) + quant_weight, scale = self.apply_weight_quantization( + weight_tensor[i]) weight_list.append(quant_weight) weight_scale_list.append(scale) quanted_weight = paddle.stack(weight_list, axis=0) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index c3bb8d3f1d..9f5b6d9206 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -14,7 +14,6 @@ # limitations under the License. """ -import numpy as np import paddle from paddle import nn from paddleformers.utils.log import logger @@ -23,8 +22,8 @@ import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm from fastdeploy.distributed.communication_op import \ tensor_model_parallel_all_reduce -from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func from ..utils import create_and_set_parameter from .fused_moe_backend_base import MoEMethodBase @@ -35,6 +34,14 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend. """ + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" + from fastdeploy.model_executor.layers.utils import \ + per_block_cast_to_fp8 + quant_weight, scale = per_block_cast_to_fp8( + weight, self.quant_config.weight_block_size) + return quant_weight, scale + def create_weights(self, layer: nn.Layer, state_dict): """ deepgemm create weight process. 
@@ -51,10 +58,8 @@ def create_weights(self, layer: nn.Layer, state_dict): weight_list = [] weight_scale_list = [] for i in range(layer.num_local_experts): - from fastdeploy.model_executor.layers.utils import \ - per_block_cast_to_fp8 - quant_weight, scale = per_block_cast_to_fp8( - weight_tensor[i], self.quant_config.weight_block_size) + quant_weight, scale = self.apply_weight_quantization( + weight_tensor[i]) weight_list.append(quant_weight) weight_scale_list.append(scale) @@ -67,9 +72,9 @@ def create_weights(self, layer: nn.Layer, state_dict): [0, 2, 1]).contiguous() create_and_set_parameter(layer, scale_name, quanted_weight_scale) - def process_prequanted_weights(self, layer: nn.Layer, state_dict): + def process_quantized_weights(self, layer: nn.Layer, state_dict): """ - Paddle cutlass process prequanted weights. + Paddle cutlass process quantized weights. """ ffn1_expert_weight_key = layer.weight_key_map.get( "ffn1_expert_weight_key", None) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 267dab451f..3979d17472 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -49,8 +49,8 @@ def __init__(self, quant_config=None): "moe_ffn1_weight_scale", "moe_ffn2_weight_scale" ] - def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: - """process_prequanted_weights""" + def process_quantized_weights(self, layer: nn.Layer, state_dict) -> None: + """process_quantized_weights""" pass def create_weights(self, layer: nn.Layer, state_dict): @@ -261,8 +261,25 @@ def __init__(self, quant_method=None): """ self.quant_method = quant_method - def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: - """process_prequanted_weights""" + def create_tp_dict(self, model_config, layer): + """create_tp_dict""" + from fastdeploy.model_executor.models.tp_utils import \ + TensorSplitMode as tsm + from fastdeploy.model_executor.models.utils import WeightMeta + ffn1_expert_weight_key = layer.weight_key_map.get( + "ffn1_expert_weight_key", None) + ffn2_expert_weight_key = layer.weight_key_map.get( + "ffn2_expert_weight_key", None) + for i in range(layer.num_experts): + ff1_key = ffn1_expert_weight_key.format(i) + ff1_weight_meta = WeightMeta(ff1_key, True, tsm.PairFused) + ff2_key = ffn2_expert_weight_key.format(i) + ff2_weight_meta = WeightMeta(ff2_key, False) + model_config.weight_infos_dict[ff1_key] = ff1_weight_meta + model_config.weight_infos_dict[ff2_key] = ff2_weight_meta + + def process_quantized_weights(self, layer: nn.Layer, state_dict) -> None: + """process_quantized_weights""" ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict) assert ffn1_tensor[0].shape == [ diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index 99e156d617..6a39d2c39d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -34,9 +34,9 @@ def __init__(self, quant_config): super().__init__() self.moe_quant_type = quant_config.moe_quant_type - def process_loaded_weights(self, layer, weights) -> None: + def process_unquantized_weights(self, layer, weights) -> None: """ - process_loaded_weights + process_unquantized_weights """ pass @@ -67,13 +67,13 @@ def __init__(self, quant_config): 
super().__init__(quant_config) self.moe_quant_type = quant_config.moe_quant_type - def process_loaded_weights(self, layer, weights) -> None: + def process_unquantized_weights(self, layer, weights) -> None: """ - process_loaded_weights + process_unquantized_weights """ pass - def process_prequanted_weights(self, layer: nn.Layer, state_dict): + def process_quantized_weights(self, layer: nn.Layer, state_dict): """ Paddle cutlass process prequanted weights. """ diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index a14b4e2cca..3ddf0dae2b 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -92,6 +92,9 @@ def __init__( if moe_quant_config: self.quant_method = moe_quant_config.get_quant_method(self) self.moe_quant_type = moe_quant_config.name() + if hasattr(self.quant_method, 'create_tp_dict') and callable( + getattr(self.quant_method, 'create_tp_dict')): + self.quant_method.create_tp_dict(fd_config.model_config, self) else: # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future from .fused_moe_cutlass_backend import CutlassMoEMethod @@ -289,7 +292,7 @@ def load_state_dict(self, state_dict): self.gate_weight.set_value(gate_weight_tensor.astype("float32")) if self.fd_config.model_config.is_quantized: - self.quant_method.process_prequanted_weights(self, state_dict) + self.quant_method.process_quantized_weights(self, state_dict) else: self.quant_method.create_weights(self, state_dict) diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 6d25df3454..4dbcfbba49 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -73,9 +73,18 @@ def __init__( self.quant_scale: Optional[float] = quant_scale self._dtype: str = self._helper.get_default_dtype() self._norm_weight_dtype: str = self._dtype - self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + if fd_config.quant_config: + self.quant_round_type: int = fd_config.quant_config.get_quant_method( + self).quant_round_type + self.quant_max_bound: int = fd_config.quant_config.get_quant_method( + self).quant_max_bound + self.quant_min_bound: int = fd_config.quant_config.get_quant_method( + self).quant_min_bound + else: + self.quant_round_type: int = 0 + self.quant_max_bound: int = 0 + self.quant_min_bound: int = 0 + self.begin_norm_axis: int = begin_norm_axis self.init_weight() @@ -197,9 +206,17 @@ def __init__( self._dtype: str = self._helper.get_default_dtype() self._norm_weight_dtype: str = "float32" - self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + if fd_config.quant_config: + self.quant_round_type: int = fd_config.quant_config.get_quant_method( + self).quant_round_type + self.quant_max_bound: int = fd_config.quant_config.get_quant_method( + self).quant_max_bound + self.quant_min_bound: int = fd_config.quant_config.get_quant_method( + 
self).quant_min_bound + else: + self.quant_round_type: int = 0 + self.quant_max_bound: int = 0 + self.quant_min_bound: int = 0 self.init_weight() diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index dea8c703b8..95065b8fb6 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -20,7 +20,7 @@ import fastdeploy from fastdeploy.model_executor.layers.moe import FusedMoE -from ..utils import per_block_cast_to_fp8, get_tensor +from ..utils import get_tensor, per_block_cast_to_fp8 from .quant_base import QuantConfigBase, QuantMethodBase @@ -34,9 +34,6 @@ class BlockWiseFP8Config(QuantConfigBase): def __init__(self, weight_block_size: list = [-1, -1]) -> None: super().__init__() self.weight_block_size = weight_block_size - self.quant_max_bound = 448 - self.quant_min_bound = -448 - self.quant_round_type = 1 def name(self) -> str: return "block_wise_fp8" @@ -47,9 +44,9 @@ def from_config(cls, config: dict) -> "BlockWiseFP8Config": return cls(weight_block_size) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: - ''' + """ Get quantization method. - ''' + """ if isinstance(layer, FusedMoE): from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \ DeepGemmFusedMoeMethod @@ -69,6 +66,9 @@ def __init__( ) -> None: super().__init__() self.quant_config = quant_config + self.quant_max_bound = 448 + self.quant_min_bound = -448 + self.quant_round_type = 1 def create_weights(self, layer): layer.linear_weight_shape.reverse() @@ -84,16 +84,23 @@ def create_weights(self, layer): ) layer.weight_dtype = "float8_e4m3fn" - def process_loaded_weights(self, layer, weights) -> None: + def apply_weight_quantization(self, weight_tensor): + """apply_weight_quantization""" + quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8( + weight_tensor) + return quanted_weight_tensor, weight_block_scale_tensor + + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" weight_tensor = weights.transpose([1, 0]) quanted_weight_tensor, weight_block_scale_tensor = ( - per_block_cast_to_fp8(weight_tensor)) + self.apply_weight_quantization(weight_tensor)) layer.linear_weight.copy_(quanted_weight_tensor, False) layer.linear_weight_scale.set_value(weight_block_scale_tensor) - def process_prequanted_weights(self, layer, state_dict): + def process_quantized_weights(self, layer, state_dict): """ - process_prequanted_weights + process_quantized_weights """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py index 4868b346bf..73db1c8049 100644 --- a/fastdeploy/model_executor/layers/quantization/mix_quant.py +++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py @@ -29,42 +29,36 @@ class MixQuantConfig(QuantConfigBase): def __init__( self, - dense_quant_type: str, - moe_quant_type: str, - kv_cache_quant_type: str = None, - image_moe_quant_type: str = None, + config, ) -> None: super().__init__() - self.dense_quant_type = dense_quant_type - self.moe_quant_type = moe_quant_type - self.kv_cache_quant_type = kv_cache_quant_type + self.dense_quant_type = config["dense_quant_type"] + self.moe_quant_type = config["moe_quant_type"] + image_moe_quant_type = 
config.get("image_moe_quant_type", None) + self.kv_cache_quant_type = config.get("kv_cache_quant_type", None) if image_moe_quant_type is None: - self.image_moe_quant_type = moe_quant_type + self.image_moe_quant_type = self.moe_quant_type else: self.image_moe_quant_type = image_moe_quant_type - self.quant_max_bound = 0 - self.quant_min_bound = 0 - self.quant_round_type = 0 + self.config = config def name(self) -> str: return "mix_quant" @classmethod def from_config(cls, config: dict) -> "MixQuantConfig": - return cls(config['dense_quant_type'], config['moe_quant_type'], - config.get('kv_cache_quant_type', None), - config.get('image_moe_quant_type', None)) + return cls(config) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: if isinstance(layer, FusedMoE): if layer.moe_tag == "Image": - return get_quantization_config( + return (get_quantization_config( self.image_moe_quant_type).from_config( - {}).get_quant_method(layer) + self.config).get_quant_method(layer)) else: - return get_quantization_config( + return (get_quantization_config( self.moe_quant_type).from_config( - {}).get_quant_method(layer) + self.config).get_quant_method(layer)) elif isinstance(layer, Attention): if self.kv_cache_quant_type is not None: return (get_quantization_config("kvcache").from_config( @@ -72,5 +66,5 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]: else: return None else: - return get_quantization_config(self.dense_quant_type).from_config( - {}).get_quant_method(layer) + return (get_quantization_config(self.dense_quant_type).from_config( + self.config).get_quant_method(layer)) diff --git a/fastdeploy/model_executor/layers/quantization/quant_base.py b/fastdeploy/model_executor/layers/quantization/quant_base.py index 40df4aaf92..e90b8bcfb2 100644 --- a/fastdeploy/model_executor/layers/quantization/quant_base.py +++ b/fastdeploy/model_executor/layers/quantization/quant_base.py @@ -34,8 +34,19 @@ def apply(self, layer, *args, **kwargs): Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def process_loaded_weights(self, layer, weights): - """Process the weight after loading. + def apply_weight_quantization(self, weight): + """Apply the weight quantization. + """ + return + + def process_quantized_weights(self, layer, state_dict): + """ + process_quantized_weights + """ + return + + def process_unquantized_weights(self, layer, weights): + """Process the nonquant weight after loading. This can be used for example, to transpose weights for computation. 
""" diff --git a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py index 06992954c0..9ab4d40f1e 100644 --- a/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py @@ -16,6 +16,7 @@ from typing import Optional import paddle +from paddleformers.utils.log import logger from fastdeploy.model_executor.layers.moe import FusedMoE @@ -78,13 +79,36 @@ def __init__( self.quant_round_type = 1 self.weight_dtype = "float8_e4m3fn" + def create_tp_dict(self, model_config, layer): + """create_tp_dict""" + from fastdeploy.model_executor.models.tp_utils import \ + TensorSplitMode as tsm + from fastdeploy.model_executor.models.utils import WeightMeta + + weight_key = layer.weight_key + if weight_key not in model_config.weight_infos_dict: + if ("up_gate_proj" in weight_key + or "shared_experts.up_gate_proj" in weight_key): + weight_meta = WeightMeta(weight_key, True, tsm.PairFused) + elif "qkv_proj" in weight_key: + weight_meta = WeightMeta(weight_key, True, tsm.GQA) + elif ("down_proj" in weight_key or "o_proj" in weight_key + or "shared_experts.down_proj" in weight_key): + weight_meta = WeightMeta(weight_key, False) + else: + logger.error(f"{weight_key} was not split.") + model_config.weight_infos_dict[weight_key] = weight_meta + def create_weights(self, layer): """ - Nothing to do! + create_weights """ - pass + if (not layer.fd_config.parallel_config.use_ep + and layer.fd_config.model_config.is_quantized and + layer.fd_config.parallel_config.tensor_parallel_degree > 1): + self.create_tp_dict(layer.fd_config.model_config, layer) - def process_prequanted_weights(self, layer, state_dict) -> None: + def process_quantized_weights(self, layer, state_dict) -> None: """ Process pre-quantized weights before applying them to the model Args: @@ -131,5 +155,6 @@ def apply(self, layer, x): bias=None, scale=self.total_scale, output_dtype="bfloat16", - activation_type="identity") + activation_type="identity", + ) return linear_out diff --git a/fastdeploy/model_executor/layers/quantization/w4afp8.py b/fastdeploy/model_executor/layers/quantization/w4afp8.py index 49453c5530..9cec3d0697 100644 --- a/fastdeploy/model_executor/layers/quantization/w4afp8.py +++ b/fastdeploy/model_executor/layers/quantization/w4afp8.py @@ -33,9 +33,6 @@ def __init__(self, weight_scale_dict, act_scale_dict) -> None: super().__init__() self.weight_scale_dict = weight_scale_dict self.act_scale_dict = act_scale_dict - self.quant_max_bound = 448 - self.quant_min_bound = -448 - self.quant_round_type = 1 def name(self) -> str: return "w4afp8" @@ -61,6 +58,9 @@ def __init__( ) -> None: super().__init__() self.quant_config = quant_config + self.quant_max_bound = 448 + self.quant_min_bound = -448 + self.quant_round_type = 1 def create_weights(self, layer): layer.linear_weight_shape.reverse() @@ -68,14 +68,26 @@ def create_weights(self, layer): layer.weight_dtype = "int8" pass - def process_loaded_weights(self, layer, weights) -> None: + def process_quantized_weights(self, layer, state_dict) -> None: + """process_quantized_weights""" + layer.linear_weight.set_value(state_dict.pop(layer.weight_key)) + layer.linear_weight_scale.set_value(state_dict.pop(layer.weight_scale)) + + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" quanted_weight_tensor, weight_scale_tensor = ( fastdeploy.model_executor.ops.gpu. 
scaled_gemm_f8_i4_f16_weight_quantize( - paddle.cast(weights, "float32").cpu(), + paddle.cast(weight, "float32").cpu(), groupsize=-1, scale_dtype="float16", )) + return quanted_weight_tensor, weight_scale_tensor + + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" + quanted_weight_tensor, weight_scale_tensor = self.apply_weight_quantization( + weights) weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype) layer.linear_weight.set_value(quanted_weight_tensor) layer.linear_weight_scale.set_value(weight_scale_tensor) diff --git a/fastdeploy/model_executor/layers/quantization/w8a8.py b/fastdeploy/model_executor/layers/quantization/w8a8.py index 8454210180..e1dbb73a09 100644 --- a/fastdeploy/model_executor/layers/quantization/w8a8.py +++ b/fastdeploy/model_executor/layers/quantization/w8a8.py @@ -37,9 +37,6 @@ def __init__(self, weight_scale_dict, act_scale_dict, use_gemm_dequant, self.act_scale_dict = act_scale_dict self.use_gemm_dequant = use_gemm_dequant self.use_smooth_quant = use_smooth_quant - self.quant_max_bound = 127 - self.quant_min_bound = -127 - self.quant_round_type = 0 def name(self) -> str: return "w8a8" @@ -67,6 +64,10 @@ def __init__( super().__init__() self.quant_config = quant_config self.smooth_quant_method = SmoothQuantLinearMethod(quant_config) + self.quant_max_bound = 127 + self.quant_min_bound = -127 + self.quant_round_type = 0 + self.weight_dtype = "int8" def create_weights(self, layer): layer.linear_weight_shape.reverse() @@ -95,16 +96,33 @@ def create_weights(self, layer): layer.linear_out_scale.set_value( convert_to_npu_dequant_scale(linear_out_scale)) - def process_loaded_weights(self, layer, weights) -> None: + def process_quantized_weights(self, layer, state_dict) -> None: + """process_quantized_weights""" + if self.quant_config.use_smooth_quant: + self.smooth_quant_method.process_unquantized_weights( + layer, state_dict[self.weight_key]) + layer.linear_weight.set_value(state_dict.pop(layer.weight_key)) + layer.linear_weight_scale.set_value( + state_dict.pop(layer.weight_scale_key)) + + def apply_weight_quantization(self, weights): + """apply_weight_quantization""" + weight_tensor = weights.transpose([1, 0]) + weight_tensor = paddle.cast(weight_tensor, + self.quant_config.weight_dtype) + return weight_tensor, None + + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" if self.quant_config.use_smooth_quant: - self.smooth_quant_method.process_loaded_weights(layer, weights) + self.smooth_quant_method.process_unquantized_weights( + layer, weights) if self.skip_quant: logger.debug(f"{layer.prefix} skip quant") weight_tensor = weights.cast(layer._dtype) layer.linear_weight.set_value(weight_tensor) else: - weight_tensor = weights.transpose([1, 0]) - weight_tensor = paddle.cast(weight_tensor, "int8") + weight_tensor, _ = self.apply_weight_quantization(weight_tensor) layer.linear_weight.set_value(weight_tensor) def apply(self, layer, x): @@ -147,7 +165,8 @@ def create_weights(self, layer): is_bias=False, ) - def process_loaded_weights(self, layer, weights) -> None: + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" if layer.shift_key in layer.state_dict: shift_tensor = get_tensor(layer.state_dict.pop( layer.shift_key)).astype(paddle.get_default_dtype()) diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 
9e890853b2..0f2a662115 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -46,9 +46,6 @@ def __init__( "FLAGS_weight_only_linear_arch") if self.weight_only_linear_arch is not None: self.weight_only_linear_arch = int(self.weight_only_linear_arch) - self.quant_max_bound = 0 - self.quant_min_bound = 0 - self.quant_round_type = 0 def name(self) -> str: return "weight_only" @@ -130,6 +127,9 @@ def __init__( ) -> None: super().__init__() self.quant_config = quant_config + self.quant_max_bound = 0 + self.quant_min_bound = 0 + self.quant_round_type = 0 def create_weights(self, layer): @@ -147,7 +147,8 @@ def create_weights(self, layer): ) @abstractmethod - def process_loaded_weights(self, layer, weights) -> None: + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" raise NotImplementedError def apply(self, layer, x): @@ -176,7 +177,14 @@ def __init__( ) -> None: super().__init__(quant_config) - def process_prequanted_weights(self, layer, state_dict) -> None: + def reorder(self, weight) -> paddle.Tensor: + """ + reorder + """ + quant_weight, _ = self.apply_weight_quantization(weight) + return quant_weight + + def process_quantized_weights(self, layer, state_dict) -> None: """ Process pre-quantized weights before applying them to the model Args: @@ -186,18 +194,29 @@ def process_prequanted_weights(self, layer, state_dict) -> None: """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) - layer.linear_weight.set_value(quant_weight) - layer.linear_weight_scale.set_value( - weight_scale.astype(paddle.get_default_dtype())) - - def process_loaded_weights(self, layer, weight) -> None: + if layer.fd_config.quant_config.moe_quant_type == "w4a8": + quant_weight = self.reorder(quant_weight) + layer.linear_weight.set_value(quant_weight) + layer.linear_weight_scale.set_value( + weight_scale.astype(paddle.get_default_dtype()) / 127.0) + else: + layer.linear_weight.set_value(quant_weight) + layer.linear_weight_scale.set_value( + weight_scale.astype(paddle.get_default_dtype())) + def apply_weight_quantization(self, weight): + """apply_weight_quantization""" quanted_weight_tensor, weight_scale_tensor = weight_quantize( weight, algo=self.quant_config.algo, arch=self.quant_config.weight_only_linear_arch, ) + return quanted_weight_tensor, weight_scale_tensor + def process_unquantized_weights(self, layer, weights) -> None: + """process_unquantized_weights""" + quanted_weight_tensor, weight_scale_tensor = self.apply_weight_quantization( + weights) layer.linear_weight.set_value(quanted_weight_tensor) layer.linear_weight_scale.set_value( weight_scale_tensor.astype(paddle.get_default_dtype())) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index c8ba1f673b..85cbe454b3 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -16,7 +16,9 @@ import json import os +from typing import Dict, Generator, Union +import numpy as np import paddle import paddle.distributed as dist from fastsafetensors import SafeTensorsFileLoader, SingleGroup @@ -49,15 +51,15 @@ def load_ep_checkpoint(model_path: str, config.num_experts_start_offset + config.num_experts_per_rank, ): ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight" - ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight") + 
ffn2_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight" - ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight" - ffn2_quant_key = ( - f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight") + ffn1_quant_key = ( + f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight") + ffn2_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight" - ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale" - ffn2_scale_key = ( - f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale") + ffn1_scale_key = ( + f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale") + ffn2_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale" num_local_ffn_keys.append(ffn1_key) num_local_ffn_keys.append(ffn2_key) num_local_ffn_keys.append(ffn1_quant_key) @@ -75,7 +77,7 @@ def load_ep_checkpoint(model_path: str, # Open each safetensor file sequentially with progress bar for safetensor_path in tqdm(safetensor_paths, - desc="Loading safetensor files", + desc="Loading safetensors checkpoint shards", unit="file"): with safe_open(os.path.join(model_path, safetensor_path), framework="np", @@ -92,7 +94,9 @@ def load_ep_checkpoint(model_path: str, return state_dict -def safetensors_weights_iterator(safe_tensor_list: list[str], ): +def safetensors_weights_iterator( + safe_tensor_list: list[str] +) -> Generator[tuple[str, np.ndarray], None, None]: """ safetensors_weights_iterator """ @@ -106,7 +110,9 @@ def safetensors_weights_iterator(safe_tensor_list: list[str], ): yield name, param -def fastsafetensors_weights_iterator(safetensor_list: list[str], ): +def fastsafetensors_weights_iterator( + safetensor_list: list[str] +) -> Generator[tuple[str, paddle.Tensor], None, None]: """ Return an iterator over tensors on GPU from a given safetensor_list. """ @@ -187,18 +193,16 @@ def get_all_safetensors(model_path: str): return key_name_list, safetensor_list -def load_tp_checkpoint_v1( - model_path: str, +def tp_weights_iterator( cls: PretrainedModel, fd_config: FDConfig, + safetensor_keys: list[str], + safetensor_files: list[str], use_fastsafetensor: bool = True, -): +) -> Generator[tuple[str, Union[paddle.Tensor, np.ndarray]], None, None]: """ - load_tp_checkpoint_v1 + Iterate over tensor-parallel-sliced weights. 
""" - - safetensor_keys, safetensor_files = get_all_safetensors(model_path) - if use_fastsafetensor: weights_iterator = fastsafetensors_weights_iterator(safetensor_files) else: @@ -212,17 +216,20 @@ def load_tp_checkpoint_v1( safetensor_keys, ) need_tp = True if tensor_parallel_filtered_map else False - state_dict = {} for key, weight in weights_iterator: - paddle.device.synchronize() + if isinstance(weight, paddle.Tensor): + paddle.device.cuda.synchronize() if need_tp and key in tensor_parallel_filtered_map: action = tensor_parallel_filtered_map.pop(key) - tensor = action(weight).clone() + tensor = action(weight.detach().clone()) if isinstance( + weight, paddle.Tensor) else action(weight.copy()) else: - tensor = weight.clone() - state_dict[key] = tensor - weight.value().get_tensor()._clear() - return state_dict + tensor = weight.detach().clone() if isinstance( + weight, paddle.Tensor) else weight.copy() + if isinstance(weight, paddle.Tensor): + weight.value().get_tensor()._clear() + del weight + yield key, tensor def deal_state_dict(state_dict): @@ -238,6 +245,30 @@ def deal_state_dict(state_dict): src_tensor._share_data_with(dst_tensor) +def load_tp_checkpoint_v1( + model_path: str, + cls: PretrainedModel, + fd_config: FDConfig, + use_fastsafetensor: bool = True, +) -> Dict[str, Union[paddle.Tensor, np.ndarray]]: + """load_tp_checkpoint""" + safetensor_keys, safetensor_files = get_all_safetensors(model_path) + weights_iterator = tp_weights_iterator( + cls, + fd_config, + safetensor_keys, + safetensor_files, + use_fastsafetensor=use_fastsafetensor, + ) + state_dict = {} + for key, weight in weights_iterator: + state_dict[key] = weight + deal_state_dict(state_dict) + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + return state_dict + + def load_composite_checkpoint( model_path: str, cls: PretrainedModel, @@ -274,11 +305,13 @@ def load_composite_checkpoint( if fd_config.load_config.use_fastsafetensor and ( current_platform.available() and current_platform.is_cuda()): - state_dict = load_tp_checkpoint_v1(model_path, - cls, - fd_config, - use_fastsafetensor=True) - deal_state_dict(state_dict) + state_dict = load_tp_checkpoint_v1( + model_path, + cls, + fd_config, + use_fastsafetensor=fd_config.load_config. + use_fastsafetensor, + ) else: state_dict = load_tp_checkpoint(model_path, cls, diff --git a/fastdeploy/model_executor/model_loader/__init__.py b/fastdeploy/model_executor/model_loader/__init__.py new file mode 100644 index 0000000000..fdaf619909 --- /dev/null +++ b/fastdeploy/model_executor/model_loader/__init__.py @@ -0,0 +1,33 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from fastdeploy.config import LoadConfig, LoadFormat, ModelConfig +from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader +from fastdeploy.model_executor.model_loader.default_loader import \ + DefaultModelLoader +from fastdeploy.model_executor.model_loader.load_time_quantization_loader import \ + LoadTimeQuantizationModelLoader + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """get_model_loader""" + if load_config.load_format == LoadFormat.LoadTimeQuant: + return LoadTimeQuantizationModelLoader(load_config) + + return DefaultModelLoader(load_config) + + +__all__ = ["get_model_loader"] diff --git a/fastdeploy/model_executor/model_loader/base_loader.py b/fastdeploy/model_executor/model_loader/base_loader.py new file mode 100644 index 0000000000..9b2c8474ec --- /dev/null +++ b/fastdeploy/model_executor/model_loader/base_loader.py @@ -0,0 +1,37 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from abc import ABC, abstractmethod + +from paddle import nn + +from fastdeploy.config import FDConfig, LoadConfig, ModelConfig + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, load_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model(self, fd_config: FDConfig) -> nn.Layer: + """Load a model with the given configurations.""" + raise NotImplementedError diff --git a/fastdeploy/model_executor/model_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py similarity index 55% rename from fastdeploy/model_executor/model_loader.py rename to fastdeploy/model_executor/model_loader/default_loader.py index 03ea7fcc67..bded7708df 100644 --- a/fastdeploy/model_executor/model_loader.py +++ b/fastdeploy/model_executor/model_loader/default_loader.py @@ -14,71 +14,31 @@ # limitations under the License. 
""" -from abc import ABC, abstractmethod - import paddle from paddle import nn +from paddleformers.utils.log import logger from fastdeploy.config import FDConfig, LoadConfig, ModelConfig from fastdeploy.model_executor.load_weight_utils import \ load_composite_checkpoint -from fastdeploy.model_executor.models.deepseek_v3 import \ - DeepSeekV3PretrainedModel -from fastdeploy.model_executor.models.ernie4_5_moe import \ - Ernie4_5_PretrainedModel -from fastdeploy.model_executor.models.ernie4_5_mtp import \ - Ernie4_5_MTPPretrainedModel +from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader +from fastdeploy.model_executor.model_loader.utils import get_pretrain_cls from fastdeploy.model_executor.models.model_base import ModelRegistry -from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel -from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel -from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel from fastdeploy.platforms import current_platform -MODEL_CLASSES = { - "Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel, - "Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel, - "Qwen2ForCausalLM": Qwen2PretrainedModel, - "Qwen3ForCausalLM": Qwen3PretrainedModel, - "Qwen3MoeForCausalLM": Qwen3MoePretrainedModel, - "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel, - "DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel, -} - - -def get_model_from_loader(fd_config: FDConfig) -> nn.Layer: - """ load or download model """ - model_loader = DefaultModelLoader(fd_config.load_config) - model = model_loader.load_model(fd_config) - return model - - -class BaseModelLoader(ABC): - """ Base class for model loaders. """ - - def __init__(self, load_config: LoadConfig): - self.load_config = load_config - - @abstractmethod - def download_model(self, load_config: ModelConfig) -> None: - """ Download a model so that it can be immediately loaded.""" - raise NotImplementedError - - @abstractmethod - def load_model(self, fd_config: FDConfig) -> nn.Layer: - """ Load a model with the given configurations.""" - raise NotImplementedError - class DefaultModelLoader(BaseModelLoader): """ ModelLoader that can load registered models """ def __init__(self, load_config: LoadConfig): super().__init__(load_config) + logger.info("Load the model and weights using DefaultModelLoader") def download_model(self, model_config: ModelConfig) -> None: + """download_model""" pass - def clean_memory_fragments(self, state_dict: dict) -> None: + def _clean_memory_fragments(self, state_dict: dict) -> None: """clean_memory_fragments""" if current_platform.is_cuda(): if state_dict: @@ -89,10 +49,11 @@ def clean_memory_fragments(self, state_dict: dict) -> None: paddle.device.synchronize() def load_model(self, fd_config: FDConfig) -> nn.Layer: + """load_model""" context = paddle.LazyGuard() architectures = fd_config.model_config.architectures[0] # TODO(gongshaotian): Now, only support safetensor - model_class = MODEL_CLASSES[architectures] + model_class = get_pretrain_cls(architectures) with context: model_cls = ModelRegistry.get_class(architectures) @@ -111,5 +72,5 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer: return_numpy=True, ) model.set_state_dict(state_dict) - self.clean_memory_fragments(state_dict) + self._clean_memory_fragments(state_dict) return model diff --git a/fastdeploy/model_executor/model_loader/default_loader.py.rej b/fastdeploy/model_executor/model_loader/default_loader.py.rej new file mode 100644 index 0000000000..d07a89af47 --- /dev/null 
+++ b/fastdeploy/model_executor/model_loader/default_loader.py.rej
@@ -0,0 +1,14 @@
+diff a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py (rejected hunks)
+@@ -89,10 +49,11 @@ class DefaultModelLoader(BaseModelLoader):
+             paddle.device.cuda.synchronize()
+ 
+     def load_model(self, fd_config: FDConfig) -> nn.Layer:
++        """load_model"""
+         context = paddle.LazyGuard()
+         architectures = fd_config.model_config.architectures[0]
+         # TODO(gongshaotian): Now, only support safetensor
+-        model_class = MODEL_CLASSES[architectures]
++        model_class = get_pretrain_cls(architectures)
+ 
+         with context:
+             model_cls = ModelRegistry.get_class(architectures)
diff --git a/fastdeploy/model_executor/model_loader/load_time_quantization_loader.py b/fastdeploy/model_executor/model_loader/load_time_quantization_loader.py
new file mode 100644
index 0000000000..78e0d9accc
--- /dev/null
+++ b/fastdeploy/model_executor/model_loader/load_time_quantization_loader.py
@@ -0,0 +1,140 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from functools import partial
+from typing import Dict, Generator, Union
+
+import numpy as np
+import paddle
+from paddle import nn
+from paddleformers.utils.log import logger
+
+from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
+from fastdeploy.model_executor.load_weight_utils import (deal_state_dict,
+                                                         get_all_safetensors,
+                                                         tp_weights_iterator)
+from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
+from fastdeploy.model_executor.model_loader.utils import get_pretrain_cls
+from fastdeploy.model_executor.models.model_base import ModelRegistry
+from fastdeploy.model_executor.models.quant_utils import (
+    apply_quant_action, check_quantization_prerequisites,
+    get_quant_layer_instance_map)
+from fastdeploy.model_executor.models.utils import switch_config_context
+from fastdeploy.platforms import current_platform
+
+
+class LoadTimeQuantizationModelLoader(BaseModelLoader):
+    """ModelLoader that can load registered models"""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        logger.info(
+            "Load the model and weights using LoadTimeQuantizationModelLoader")
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        """download_model"""
+        pass
+
+    def _clean_memory_fragments(self, state_dict: dict) -> None:
+        """clean_memory_fragments"""
+        if current_platform.is_cuda():
+            if state_dict:
+                for k, v in state_dict.items():
+                    if isinstance(v, paddle.Tensor):
+                        v.value().get_tensor()._clear()
+            paddle.device.cuda.empty_cache()
+            paddle.device.cuda.synchronize()
+
+    def _get_quantized_weights(
+        self,
+        weight_iterator: Generator[tuple[str, Union[paddle.Tensor,
+                                                     np.ndarray]], None, None],
+        quant_filtered_map: Dict[str, partial],
+        quant_layer_instance_map: Dict[str, nn.Layer],
+    ) -> Dict[str, paddle.Tensor]:
+        """_get_quantized_weights"""
+        state_dict = {}
+        for key, weight in weight_iterator:
+            if not isinstance(weight, paddle.Tensor):
+                weight = paddle.Tensor(weight, zero_copy=True)
+                weight = weight._copy_to(
+                    paddle.framework._current_expected_place(), False)
+            if key in quant_filtered_map:
+                apply_quant_action(
+                    quant_filtered_map,
+                    key,
+                    weight,
+                    state_dict,
+                    quant_layer_instance_map,
+                )
+            else:
+                state_dict[key] = weight
+        deal_state_dict(state_dict)
+        paddle.device.cuda.empty_cache()
+        paddle.device.cuda.synchronize()
+        return state_dict
+
+    def load_model(self, fd_config: FDConfig) -> nn.Layer:
+        """load_model"""
+        assert not fd_config.model_config.is_quantized
+        with switch_config_context(fd_config.model_config, "is_quantized",
+                                   True):
+            context = paddle.LazyGuard()
+            architectures = fd_config.model_config.architectures[0]
+            model_class = get_pretrain_cls(architectures)
+            model_cls = ModelRegistry.get_class(architectures)
+
+            with context:
+                model = model_cls(fd_config)
+            model.eval()
+            safetensor_keys, safetensor_files = get_all_safetensors(
+                fd_config.parallel_config.model_name_or_path)
+            if not self.load_config.use_fastsafetensor:
+                logger.info(
+                    "Tip: For faster model loading, consider enabling FastSafeTensor by setting:\n"
+                    "    export FD_USE_FASTSAFETENSOR=1")
+            weights_iterator = tp_weights_iterator(
+                model_class,
+                fd_config,
+                safetensor_keys,
+                safetensor_files,
+                use_fastsafetensor=self.load_config.use_fastsafetensor,
+            )
+            model_dict = dict(model.named_sublayers())
+            quant_filtered_map = {}
+
+            check_quantization_prerequisites(
+                fd_config,
+                model_class,
+                quant_filtered_map,
+                safetensor_keys,
+                model_dict,
+            )
+            need_quant = True if quant_filtered_map else False
+            quant_layer_instance_map = {}
+            if need_quant:
+                quant_layer_instance_map = get_quant_layer_instance_map(
+                    model_class, model_dict)
+            need_quant = True if need_quant and quant_layer_instance_map else False
+
+            assert need_quant, "Quantization must be enabled"
+
+            state_dict = self._get_quantized_weights(weights_iterator,
+                                                     quant_filtered_map,
+                                                     quant_layer_instance_map)
+            model.set_state_dict(state_dict)
+            self._clean_memory_fragments(state_dict)
+            return model
diff --git a/fastdeploy/model_executor/model_loader/utils.py b/fastdeploy/model_executor/model_loader/utils.py
new file mode 100644
index 0000000000..933b78dec1
--- /dev/null
+++ b/fastdeploy/model_executor/model_loader/utils.py
@@ -0,0 +1,42 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +from paddleformers.transformers import PretrainedModel + +from fastdeploy.model_executor.models.deepseek_v3 import \ + DeepSeekV3PretrainedModel +from fastdeploy.model_executor.models.ernie4_5_moe import \ + Ernie4_5_PretrainedModel +from fastdeploy.model_executor.models.ernie4_5_mtp import \ + Ernie4_5_MTPPretrainedModel +from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel +from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel +from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel + +MODEL_CLASSES = { + "Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel, + "Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel, + "Qwen2ForCausalLM": Qwen2PretrainedModel, + "Qwen3ForCausalLM": Qwen3PretrainedModel, + "Qwen3MoeForCausalLM": Qwen3MoePretrainedModel, + "Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel, + "DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel +} + + +def get_pretrain_cls(architectures: str) -> PretrainedModel: + """get_pretrain_cls""" + return MODEL_CLASSES[architectures] diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index f6b73622a9..6c133742a4 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -94,30 +94,7 @@ class Ernie4_5_MoE(nn.Layer): def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: super().__init__() - moe_quant_type = "" - if hasattr(fd_config.quant_config, 'moe_quant_type'): - moe_quant_type = fd_config.quant_config.moe_quant_type - - if moe_quant_type == "w4a8": - weight_key_map = { - "gate_weight_key": - f"{prefix}.gate.weight", - "gate_correction_bias_key": - f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": - f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", - "ffn2_expert_weight_key": - f"{prefix}.experts.{{}}.down_proj.quant_weight", - "ffn1_expert_weight_scale_key": - f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", - "ffn2_expert_weight_scale_key": - f"{prefix}.experts.{{}}.down_proj.weight_scale", - "ffn1_expert_in_scale_key": - f"{prefix}.experts.{{}}.up_gate_proj.activation_scale", - "ffn2_expert_in_scale_key": - f"{prefix}.experts.{{}}.down_proj.activation_scale", - } - elif moe_quant_type == "w4w2": + if fd_config.model_config.is_quantized: weight_key_map = { "gate_weight_key": f"{prefix}.gate.weight", @@ -143,23 +120,6 @@ def __init__(self, fd_config: FDConfig, layer_id: int, f"{prefix}.experts.{{}}.up_gate_proj.code_zp", "ffn2_expert_code_zp_key": f"{prefix}.experts.{{}}.down_proj.code_zp", - } - elif moe_quant_type == "tensor_wise_fp8" or ( - moe_quant_type == "block_wise_fp8" - and fd_config.model_config.is_quantized): - weight_key_map = { - "gate_weight_key": - f"{prefix}.gate.weight", - "gate_correction_bias_key": - f"{prefix}.moe_statics.e_score_correction_bias", - "ffn1_expert_weight_key": - f"{prefix}.experts.{{}}.up_gate_proj.quant_weight", - "ffn2_expert_weight_key": - f"{prefix}.experts.{{}}.down_proj.quant_weight", - "ffn1_expert_weight_scale_key": - f"{prefix}.experts.{{}}.up_gate_proj.weight_scale", - "ffn2_expert_weight_scale_key": - f"{prefix}.experts.{{}}.down_proj.weight_scale", "ffn1_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale", "ffn2_expert_in_scale_key": @@ -509,55 +469,77 @@ def _init_weight(self, layer): weight_infos = [ WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight", - True, tsm.GQA), + True, tsm.GQA, "qkv_proj"), 
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", - False), + False, None, "o_proj"), WeightMeta( f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight", - True, tsm.PairFused), + True, tsm.PairFused, "gate_up_proj"), WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", - False), + False, None, "down_proj"), WeightMeta( f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight", - True, tsm.PairFused), + True, tsm.PairFused, "fused_moe"), WeightMeta( f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight", - False), + False, None, "fused_moe"), WeightMeta( f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight", - True, tsm.PairFused), + True, tsm.PairFused, "gate_up_proj"), WeightMeta( f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight", - False), + False, None, "down_proj"), WeightMeta(".embed_tokens.weight", False), WeightMeta("lm_head.weight", True), - # quant tensorwise - WeightMeta( - f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight", - True, tsm.GQA), - WeightMeta( - f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight", - False), - WeightMeta( - f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight", - True, tsm.PairFused), - WeightMeta( - f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight", - False), - WeightMeta( - f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight", - True, tsm.PairFused), - WeightMeta( - f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight", - False), - WeightMeta( - f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight", - True, tsm.PairFused), - WeightMeta( - f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight", - False), ] + # quant_need_find_layer_list: names of model layers whose weights need quantization + # e.g., if the model defines `self.qkv_proj = QKVParallelLinear(...)` and qkv needs quantization, + # add "qkv_proj" to this list + quant_need_find_layer_list = { + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "fused_moe" + } + + @classmethod + def _get_quantization_mappings(cls, fd_config: FDConfig): + """ + _get_quantization_mappings + """ + logger.info("erine bot inference model _get_quantization_mappings") + from fastdeploy.model_executor.models.quant_utils import \ + quantization_func + from fastdeploy.model_executor.models.tp_utils import \ + build_expanded_keys + + fn = quantization_func(fd_config) + + def get_tensor_quantization_mappings(fd_config: FDConfig): + base_actions = {} + for (weight_name, _, _, quant_layer_key) in cls.weight_infos: + if not quant_layer_key: + continue + params = { + "quant_layer_key": quant_layer_key, + } + key = f"{fd_config.model_config.prefix_name}{weight_name}" + base_actions[key] = partial(fn, **params) + final_actions = {} + start_layer = (fd_config.moe_config.moe_layer_start_index + if fd_config.moe_config.moe_layer_start_index > 0 + else fd_config.model_config.num_layers) + final_actions = build_expanded_keys( + fd_config.model_config.num_layers, + fd_config.moe_config.num_experts, + start_layer, + base_actions, + ) + + return final_actions + + mappings = get_tensor_quantization_mappings(fd_config) + + return mappings + @classmethod def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): """ @@ -579,8 +561,9 @@ def 
get_tensor_parallel_split_mappings(num_layers, moe_num_experts, moe_layer_start_index, prefix_name): base_actions = {} - weight_infos = cls.weight_infos - for (weight_name, is_column, extra) in weight_infos: + weight_infos = list( + config.weight_infos_dict.values()) + cls.weight_infos + for (weight_name, is_column, extra, _) in weight_infos: params = { "is_column": is_column, **({ diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 029becc1e4..8951daef83 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -53,139 +53,16 @@ def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True): """ logger.info("erine inference model _get_tensor_parallel_mappings") - from paddleformers.transformers.conversion_utils import \ - split_or_merge_func + from fastdeploy.model_executor.models.tp_utils import \ + split_or_merge_func_v1 - fn = split_or_merge_func( + fn = split_or_merge_func_v1( is_split=is_split, tensor_parallel_degree=config.tensor_parallel_degree, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, - ) - - def gqa_qkv_split_func( - weight, - tensor_parallel_degree, - tensor_parallel_rank, - num_attention_heads, - num_key_value_heads, - head_dim, - ): - - def get_shape(tensor): - return (tensor.get_shape() - if hasattr(tensor, "get_shape") else tensor.shape) - - def slice_tensor(tensor, start, end): - shape = get_shape(tensor) - if len(shape) == 1: - return tensor[start:end] - else: - return tensor[..., start:end] - - q_end = num_attention_heads * head_dim - k_end = q_end + num_key_value_heads * head_dim - v_end = k_end + num_key_value_heads * head_dim - - q = slice_tensor(weight, 0, q_end) - k = slice_tensor(weight, q_end, k_end) - v = slice_tensor(weight, k_end, v_end) - - def split_tensor(tensor, degree): - shape = get_shape(tensor) - size = shape[-1] - block_size = size // degree - if hasattr(tensor, "get_shape"): - return [ - slice_tensor(tensor, i * block_size, - (i + 1) * block_size) - for i in range(degree) - ] - else: - return np.split(tensor, degree, axis=-1) - - q_list = split_tensor(q, tensor_parallel_degree) - k_list = split_tensor(k, tensor_parallel_degree) - v_list = split_tensor(v, tensor_parallel_degree) - - if tensor_parallel_rank is None: - return [ - np.concatenate([q_i, k_i, v_i], axis=-1) - for q_i, k_i, v_i in zip(q_list, k_list, v_list) - ] - else: - return np.concatenate( - [ - q_list[tensor_parallel_rank], - k_list[tensor_parallel_rank], - v_list[tensor_parallel_rank], - ], - axis=-1, - ) - - def gqa_qkv_merge_func(weight_list, num_attention_heads, - num_key_value_heads, head_dim): - tensor_parallel_degree = len(weight_list) - num_attention_heads = num_attention_heads // tensor_parallel_degree - num_key_value_heads = num_key_value_heads // tensor_parallel_degree - - is_paddle_tensor = not isinstance(weight_list[0], np.ndarray) - - def get_shape(tensor): - return (tensor.get_shape() - if hasattr(tensor, "get_shape") else tensor.shape) - - def slice_tensor(tensor, start, end): - if len(get_shape(tensor)) == 1: - return tensor[start:end] - else: - return tensor[..., start:end] - - q_list, k_list, v_list = [], [], [] - - for weight in weight_list: - q_end = num_attention_heads * head_dim - k_end = q_end + num_key_value_heads * head_dim - v_end = k_end + num_key_value_heads * head_dim - - q = slice_tensor(weight, 0, q_end) - k = slice_tensor(weight, q_end, k_end) - v = 
slice_tensor(weight, k_end, v_end) - - q_list.append(q) - k_list.append(k) - v_list.append(v) - - merged = q_list + k_list + v_list - - if is_paddle_tensor: - tensor = paddle.concat(merged, axis=-1) - if tensor.place.is_gpu_place(): - tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) - return tensor - else: - return np.concatenate(merged, axis=-1) - - if (config.num_key_value_heads is not None - and config.num_key_value_heads != config.num_attention_heads): - if is_split: - qkv_fn = partial( - gqa_qkv_split_func, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.hidden_size // config.num_attention_heads, - ) - else: - qkv_fn = partial( - gqa_qkv_merge_func, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=config.hidden_size // config.num_attention_heads, - ) - else: - qkv_fn = partial(fn, is_column=True) + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim) def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, moe_layer_start_index): @@ -200,7 +77,8 @@ def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, base_actions["ernie.mtp_linear_proj.0.weight"] = partial( fn, is_column=True) base_actions[ - f"{base_model_prefix}.0.self_attn.qkv_proj.weight"] = qkv_fn + f"{base_model_prefix}.0.self_attn.qkv_proj.weight"] = partial( + fn, is_column=True, is_gqa=True) base_actions[ f"{base_model_prefix}.0.self_attn.o_proj.weight"] = partial( fn, is_column=False) @@ -261,7 +139,7 @@ def __init__( """ super().__init__() - + fd_config.model_config.prefix_name = "ernie.mtp_block" self.num_layers = fd_config.model_config.num_layers self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings diff --git a/fastdeploy/model_executor/models/quant_utils.py b/fastdeploy/model_executor/models/quant_utils.py new file mode 100644 index 0000000000..2218bece35 --- /dev/null +++ b/fastdeploy/model_executor/models/quant_utils.py @@ -0,0 +1,133 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from functools import partial +from typing import Dict, Optional + +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.models.utils import switch_level_context + + +def get_quant_layer_instance_map( + cls: PretrainedModel, + model_dict: Dict[str, nn.Layer]) -> Dict[str, nn.Layer]: + """get_quant_layer_instance_map""" + suffix_set = set(cls.quant_need_find_layer_list) + quant_layer_map = {} + + remaining_suffixes = set(suffix_set) + + for key, layer in model_dict.items(): + for suffix in list(remaining_suffixes): + if key.endswith(suffix): + quant_layer_map[suffix] = layer + remaining_suffixes.remove(suffix) + break + + if not remaining_suffixes: + break + + if not quant_layer_map: + logger.error( + "quant_map should not be empty. " + "Pre-quantization is required, but _get_quantization_mappings is not implemented." + ) + + return quant_layer_map + + +def apply_quant_action( + quant_map: Dict[str, partial], + key: str, + tensor: paddle.Tensor, + state_dict: Dict[str, paddle.Tensor], + quant_layer_instance_map: Dict[str, nn.Layer], +) -> None: + """ + apply_quant_action + """ + action = quant_map.pop(key) + quant_weight_tensor, weight_quanter_tensor = action( + key, tensor, quant_layer_instance_map) + if quant_weight_tensor._is_initialized(): + quant_weight_key = key.replace("weight", "quant_weight") + state_dict[quant_weight_key] = quant_weight_tensor + if weight_quanter_tensor._is_initialized(): + weight_quanter_key = key.replace("weight", "weight_scale") + state_dict[weight_quanter_key] = weight_quanter_tensor + + +@switch_level_context("WARNING") +def check_quantization_prerequisites( + fd_config: FDConfig, + cls: PretrainedModel, + quant_filtered_map: Dict[str, partial], + safetensor_keys: list[str], + model_dict: Optional[Dict[str, nn.Layer]] = None, +) -> None: + """check_quantization_prerequisites""" + if not hasattr(cls, "_get_quantization_mappings"): + raise NotImplementedError( + f"Class {cls.__name__} must implement method '_get_quantization_mappings'" + ) + quant_map = cls._get_quantization_mappings(fd_config) + if not quant_map: + logger.error("quant_map should not be empty. \ + pre-quantization required, but _get_quantization_mappings is not implemented." + ) + else: + filtered_quant_map = cls._resolve_prefix_keys(quant_map.keys(), + safetensor_keys) + for k, v in filtered_quant_map.items(): + quant_filtered_map[v] = quant_map.pop(k) + if not filtered_quant_map: + logger.info("No weights to quantize; filtered_quant_map is empty.") + else: + if not model_dict: + logger.error( + "Missing required argument 'model_dict' when calling tp_weights_iterator." 
+ ) + + +def quantization_func(fd_config: FDConfig): + """quantization_func""" + + def fn( + key: str, + tensor: paddle.Tensor, + quant_layer_instance_map: Dict[str, nn.Layer], + quant_layer_key: str = "", + ): + """fn""" + quant_layer = quant_layer_instance_map[quant_layer_key] + quant_method = fd_config.quant_config.get_quant_method(quant_layer) + if quant_method is None: + raise ValueError("quant_method should not be None.") + try: + (quanted_weight_tensor, weight_scale_tensor) = ( + quant_method.apply_weight_quantization(tensor)) + except Exception: + raise ValueError( + f"{key} Expected apply_weight_quantization is missing from {quant_method}" + ) + return quanted_weight_tensor, weight_scale_tensor + + return fn diff --git a/fastdeploy/model_executor/models/tp_utils.py b/fastdeploy/model_executor/models/tp_utils.py index f360c5106f..08989188a8 100644 --- a/fastdeploy/model_executor/models/tp_utils.py +++ b/fastdeploy/model_executor/models/tp_utils.py @@ -26,9 +26,28 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig -from fastdeploy.model_executor.models.utils import LayerIdPlaceholder +from fastdeploy.model_executor.models.utils import (LayerIdPlaceholder, + switch_level_context) +class TensorSplitMode(Enum): + """TensorSplitMode""" + + GQA = "is_gqa" + TRANSPOSE = "transpose" + QKV = "is_old_qkv" + PairFused = "is_naive_2fuse" + TripletFused = "is_naive_3fuse" + + +class SafeDict(dict): + """SafeDict""" + + def __missing__(self, key): + return "{" + key + "}" + + +@switch_level_context("WARNING") def check_tensor_parallel_prerequisites( fd_config: FDConfig, cls: PretrainedModel, @@ -66,28 +85,11 @@ def has_prefix(prefix_name: str, weight_name: str): return prefix_name == extract_prefix(weight_name) -class TensorSplitMode(Enum): - """TensorSplitMode""" - - GQA = "is_gqa" - TRANSPOSE = "transpose" - QKV = "is_old_qkv" - PairFused = "is_naive_2fuse" - TripletFused = "is_naive_3fuse" - - def extract_placeholders(template: str): """extract_placeholders""" return set(re.findall(r"{(\w+)}", template)) -class SafeDict(dict): - """SafeDict""" - - def __missing__(self, key): - return "{" + key + "}" - - def has_placeholders(placeholders): """has_placeholders""" return len(placeholders) > 0 diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 350f10651f..d141db3702 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -23,6 +23,7 @@ import random import re import struct +from contextlib import contextmanager from functools import partial from typing import NamedTuple, Optional @@ -44,6 +45,29 @@ MAX_DRAFT_TOKENS = 6 +@contextmanager +def switch_config_context(config_obj, config_attr_name, value): + """switch_config_context""" + origin_value = getattr(config_obj, config_attr_name) + setattr(config_obj, config_attr_name, value) + try: + yield + finally: + setattr(config_obj, config_attr_name, origin_value) + + +@contextmanager +def switch_level_context(level="ERROR"): + """switch_level_context""" + original_level = logger.logLevel + logger.set_level(level) + + try: + yield + finally: + logger.set_level(original_level) + + class LayerIdPlaceholder(str, enum.Enum): """LayerIdPlaceholder""" LAYER_ID = "layer_id" @@ -59,10 +83,14 @@ class WeightMeta(NamedTuple): # weight_name: weight name # is_column: whether to split by columns # extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse" + + # quantization parameters + # layer_key: layer name this 
weight belongs to in the model """ weight_name: str is_column: bool extra: Optional[str] = None + layer_key: Optional[str] = None class UniqueIDGenerator: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 97e8364451..927af828a4 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -73,7 +73,6 @@ def _update_cfg(self, main_model): self.model_config.num_layers = 1 self.parallel_config.model_name_or_path = ( self.speculative_config.model_name_or_path) - self.model_config.prefix_name = "ernie.mtp_block" if self.speculative_config.quantization != "": self.model_config.quantization = ( self.speculative_config.quantization) @@ -84,10 +83,14 @@ def _load_model(self): """ Load MTP Layer """ - from fastdeploy.model_executor.model_loader import \ - get_model_from_loader - - self.model = get_model_from_loader(self.cfg) + from fastdeploy.config import LoadFormat + from fastdeploy.model_executor.model_loader import get_model_loader + from fastdeploy.model_executor.models.utils import \ + switch_config_context + with switch_config_context(self.cfg.load_config, "load_format", + LoadFormat.DEFAULT.value): + model_loader = get_model_loader(load_config=self.cfg.load_config) + self.model = model_loader.load_model(fd_config=self.cfg) def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d6ca79a1b..f1670e692f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -34,7 +34,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import ( Sampler, SpeculativeSampler) -from fastdeploy.model_executor.model_loader import get_model_from_loader +from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.ops.gpu import (set_value_by_flags_and_idx, share_external_data) from fastdeploy.model_executor.pre_and_post_process import (post_process, @@ -590,7 +590,8 @@ def load_model(self) -> None: f"Starting to load model {self.model_config.architectures[0]}") time_before_load = time.perf_counter() # 1. Load original model - self.model = get_model_from_loader(fd_config=self.fd_config) + model_loader = get_model_loader(load_config=self.fd_config.load_config) + self.model = model_loader.load_model(fd_config=self.fd_config) # 1.1 Load RL dynamic model if self.fd_config.load_config.dynamic_load_weight: from fastdeploy.rl.dynamic_weight_manager import \ diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index ba7a5541a5..e655d3915a 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -552,6 +552,12 @@ def parse_args(): "'meta': provide RL traing worker, no_weights_load" "'normal':normal load weight") + parser.add_argument( + "--load_format", + type=str, + default="default", + help="The format of the model weights to load. default/load_time_quant." 
+    )
 
     args = parser.parse_args()
     return args
@@ -626,6 +632,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
     parallel_config.expert_parallel_degree = args.expert_parallel_size
     parallel_config.splitwise_role = args.splitwise_role
     load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
+    load_config.load_format = args.load_format
     parallel_config.guided_decoding_backend = args.guided_decoding_backend
     parallel_config.disable_any_whitespace = args.disable_any_whitespace
 
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index b075356f99..cab064b6a3 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -29,7 +29,7 @@
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler
-from fastdeploy.model_executor.model_loader import get_model_from_loader
+from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.utils import get_logger
 from fastdeploy.worker.forward_meta import ForwardMeta, XPUForwardMeta
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -514,7 +514,8 @@ def load_model(self) -> None:
             f"Starting to load model {self.model_config.architectures[0]}")
         time_before_load = time.perf_counter()
         # 1. Load original model
-        self.model = get_model_from_loader(fd_config=self.fd_config)
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        self.model = model_loader.load_model(fd_config=self.fd_config)
         # 2. Load lora model
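
Editor's note: across these changes, the new `--load_format` CLI flag is copied into `LoadConfig`, and the GPU/XPU model runners now obtain their loader from `get_model_loader` instead of the removed `get_model_from_loader`. The factory itself (`fastdeploy/model_executor/model_loader/__init__.py`) is not part of this diff, so the sketch below is only an assumption of how the dispatch on `load_config.load_format` could look; the actual function name is taken from the diff, but its body and the string values are illustrative.

```python
# Hypothetical sketch, not the actual FastDeploy implementation.
from fastdeploy.config import LoadConfig
from fastdeploy.model_executor.model_loader.default_loader import DefaultModelLoader
from fastdeploy.model_executor.model_loader.load_time_quantization_loader import (
    LoadTimeQuantizationModelLoader)


def get_model_loader(load_config: LoadConfig):
    """Return a loader based on load_config.load_format
    ("default" or "load_time_quantization")."""
    if load_config.load_format == "load_time_quantization":
        # Quantizes bfloat16 weights to INT4/INT8/FP8 while they are being loaded.
        return LoadTimeQuantizationModelLoader(load_config)
    return DefaultModelLoader(load_config)


# Usage, mirroring the runner changes in this diff:
# model_loader = get_model_loader(load_config=fd_config.load_config)
# model = model_loader.load_model(fd_config=fd_config)
```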