From 24a0e8c00844310d156d2177028d54b9538661dd Mon Sep 17 00:00:00 2001
From: zeroRains
Date: Mon, 7 Jul 2025 20:36:10 +0800
Subject: [PATCH] Support using safetensors with paddle.MmapStorage to load model files

Change-Id: I8f6faff3d86b682ccdccc31b38eb1d6b1db5e8a1
---
 .../model_executor/load_weight_utils.py      | 230 +++++++++++++++++-
 1 file changed, 225 insertions(+), 5 deletions(-)

diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index c8ba1f673b..3b7d5f7d87 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -14,14 +14,19 @@
 # limitations under the License.
 """
 
+import concurrent
+import concurrent.futures
+import contextlib
 import json
 import os
+import re
+from typing import Union
 
+import numpy as np
 import paddle
 import paddle.distributed as dist
 from fastsafetensors import SafeTensorsFileLoader, SingleGroup
 from paddleformers.transformers import PretrainedModel
-from paddleformers.transformers.model_utils import load_tp_checkpoint
 from safetensors import safe_open
 from tqdm import tqdm
 
@@ -78,7 +83,7 @@ def load_ep_checkpoint(model_path: str,
                                 desc="Loading safetensor files",
                                 unit="file"):
         with safe_open(os.path.join(model_path, safetensor_path),
-                       framework="np",
+                       framework="pp",
                        device="cpu") as f:
             # Check if this file contains keys from filtered_map
             for k in filtered_map:
@@ -92,7 +97,7 @@ def load_ep_checkpoint(model_path: str,
     return state_dict
 
 
-def safetensors_weights_iterator(safe_tensor_list: list[str], ):
+def safetensors_weights_iterator(safe_tensor_list: list[str]):
     """
     safetensors_weights_iterator
     """
@@ -100,7 +105,7 @@ def safetensors_weights_iterator(safe_tensor_list: list[str], ):
             safe_tensor_list,
             desc="Loading safetensors checkpoint shards",
     ):
-        with safe_open(st_file, framework="np") as f:
+        with safe_open(st_file, framework="pp") as f:
             for name in f.keys():
                 param = f.get_tensor(name)
                 yield name, param
@@ -170,7 +175,7 @@ def get_all_safetensors(model_path: str):
     safe_model_path = os.path.join(model_path, "model.safetensors")
     if os.path.exists(safe_model_path):
         safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
+        with safe_open(safe_model_path, framework="pp", device="cpu") as f:
            key_name_list = f.keys()
         return key_name_list, safetensor_list
     else:
@@ -187,6 +192,221 @@ def get_all_safetensors(model_path: str):
         return key_name_list, safetensor_list
 
 
+
+def _add_variant(weights_name: str, variant=None) -> str:
+    if variant is not None and len(variant) > 0:
+        splits = weights_name.split(".")
+        splits = splits[:-1] + [variant] + splits[-1:]
+        weights_name = ".".join(splits)
+
+    return weights_name
+
+
+@contextlib.contextmanager
+def device_guard(device="cpu", dev_id=0):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)
+
+
+def _split_keys_evenly(keys: list, n: int) -> list:
+    total_len = len(keys)
+    base_size = total_len // n
+    extra = total_len % n
+
+    result = []
+    index = 0
+    for _ in range(n):
+        part_size = base_size + 1 if extra > 0 else base_size
+        extra -= 1
+        result.append(keys[index : index + part_size])
+        index += part_size
+
+    return result
+
+
+def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
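+    # Prefer a single-file checkpoint (.pdparams, LoRA .pdparams, or
+    # model.safetensors); otherwise fall back to the sharded safetensors
+    # resolved through an index file further below.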
_add_variant("model_state.pdparams", variant)) + lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant)) + safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant)) + if os.path.isfile(pdparams_file): + return paddle.load(pdparams_file, return_numpy=return_numpy) + if os.path.isfile(lora_pdparams_file): + return paddle.load(lora_pdparams_file, return_numpy=return_numpy) + if os.path.isfile(safetensors_file): + state_dict = {} + with safe_open(safetensors_file, framework="pp") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor() + if not return_numpy: + for key in list(state_dict.keys()): + if isinstance(state_dict[key], np.ndarray): + state_dict[key] = paddle.Tensor.__call__(state_dict.pop(key), zero_copy=True) + return state_dict + + PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json" + SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json" + SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json" + + index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant)) + safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) + safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant)) + safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant)) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + safe_master_present = os.path.isfile(safe_master_file) + safe_peft_present = os.path.isfile(safe_peft_file) + + load_index = None + if safe_index_present: + load_index = safe_index_file + elif safe_master_present: + load_index = safe_master_file + elif index_present: + load_index = index_file + elif safe_peft_present: + load_index = safe_peft_file + else: + raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}") + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + ret = {} + for shard_file in tqdm(shard_files): + with safe_open(os.path.join(folder, shard_file), framework="pp") as f: + for key in f.keys(): + ret[key] = f.get_tensor(key) + if not return_numpy: + for key in list(ret.keys()): + if isinstance(ret[key], np.ndarray): + ret[key] = paddle.Tensor.__call__(ret.pop(key), zero_copy=True) + return ret + +def _load_part_state_dict( + keys, + checkpoint_file: Union[str, os.PathLike], + tensor_parallel_split_mapping, + fliter_dict_keys, + return_numpy=False, +): + part_state_dict = {} + with safe_open(checkpoint_file, framework="pp") as f: + for key in keys: + py_safe_slice_ = f.get_tensor(key) + if key in tensor_parallel_split_mapping: + weight = tensor_parallel_split_mapping[key](py_safe_slice_) + else: + weight = py_safe_slice_ + if not return_numpy: + with device_guard(): + weight = paddle.Tensor.__call__(weight, zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + part_state_dict[key] = weight + return part_state_dict + +def load_tp_state_dict(checkpoint_file: Union[str, os.PathLike], + tensor_parallel_split_mapping=None, + fliter_dict_keys=None, + device="cpu", + return_numpy=False): + + if tensor_parallel_split_mapping is None: + tensor_parallel_split_mapping = {} + + if ( + checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file) + ): + thread_num = 
int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1")) + state_dict = {} + if thread_num <= 1: + with safe_open(checkpoint_file, framework="pp") as f: + state_dict = _load_part_state_dict( + list(f.keys()), + checkpoint_file, + tensor_parallel_split_mapping, + fliter_dict_keys, + return_numpy, + ) + else: + # Load state dict in multi-thread to speed up loading + with safe_open(checkpoint_file, framework="pp") as f: + keys_groups = _split_keys_evenly(list(f.keys()), thread_num) + with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor: + future_to_key = { + executor.submit( + _load_part_state_dict, + keys, + checkpoint_file, + tensor_parallel_split_mapping, + fliter_dict_keys, + return_numpy, + ): keys + for keys in keys_groups + } + for future in concurrent.futures.as_completed(future_to_key): + res_state_dict = future.result() + state_dict.update(res_state_dict) + + if not return_numpy: + if device == "cpu": + with device_guard(): + for k in list(state_dict.keys()): + state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True) + elif device == "pin_memory": + for k in list(state_dict.keys()): + state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace()) + + return state_dict + + +def load_tp_checkpoint( + folder: str, + cls: PretrainedModel, + config: ModelConfig, + return_numpy: bool = True, +): + if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1: + return load_sharded_checkpoint_as_one(folder, return_numpy=return_numpy) + rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams") + model_path = os.path.join(folder, "model_state.pdparams") + safe_model_path = os.path.join(folder, "model.safetensors") + if os.path.exists(rank_model_path): + return paddle.load(rank_model_path, return_numpy=return_numpy) + elif os.path.exists(model_path): + state_dict = cls.convert_tensor_parallel(model_path, config) + elif os.path.exists(safe_model_path): + with safe_open(safe_model_path, framework="pp", device="cpu") as f: + loaded_keys = f.keys() + tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys) + state_dict = load_tp_state_dict(safe_model_path, tp_actions, return_numpy=return_numpy) + else: # shard files safetensors + resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path( + pretrained_model_name_or_path=folder, + use_safetensors=True, + ) + if len(resolved_sharded_files) > 1: + resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards") + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True) + state_dict = {} + for shard_file in resolved_sharded_files: + shard_state_dict = load_tp_state_dict( # todo: for this function + shard_file, + tp_actions, + loaded_state_dict_keys, + return_numpy=return_numpy, + ) + state_dict.update(shard_state_dict) + return state_dict + def load_tp_checkpoint_v1( model_path: str, cls: PretrainedModel,