Skip to content

Commit 52b4ac6

Browse files
author
zeroRains
committed
support use safetensors with paddle.MmapStorage to load model files
Change-Id: I8f0d28e6c864647817183d2eb299a9bbd63e7851
1 parent a6e9161 commit 52b4ac6

File tree

1 file changed

+225
-5
lines changed

1 file changed

+225
-5
lines changed

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 225 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
# limitations under the License.
1515
"""
1616

17+
import concurrent
18+
import concurrent.futures
19+
import contextlib
1720
import json
1821
import os
22+
import re
23+
from typing import Union
1924

25+
import numpy as np
2026
import paddle
2127
import paddle.distributed as dist
2228
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
2329
from paddleformers.transformers import PretrainedModel
24-
from paddleformers.transformers.model_utils import load_tp_checkpoint
2530
from safetensors import safe_open
2631
from tqdm import tqdm
2732

@@ -78,7 +83,7 @@ def load_ep_checkpoint(model_path: str,
7883
desc="Loading safetensor files",
7984
unit="file"):
8085
with safe_open(os.path.join(model_path, safetensor_path),
81-
framework="np",
86+
framework="pp",
8287
device="cpu") as f:
8388
# Check if this file contains keys from filtered_map
8489
for k in filtered_map:
@@ -92,15 +97,15 @@ def load_ep_checkpoint(model_path: str,
9297
return state_dict
9398

9499

95-
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
100+
def safetensors_weights_iterator(safe_tensor_list: list[str] ):
96101
"""
97102
safetensors_weights_iterator
98103
"""
99104
for st_file in tqdm(
100105
safe_tensor_list,
101106
desc="Loading safetensors checkpoint shards",
102107
):
103-
with safe_open(st_file, framework="np") as f:
108+
with safe_open(st_file, framework="pp") as f:
104109
for name in f.keys():
105110
param = f.get_tensor(name)
106111
yield name, param
@@ -170,7 +175,7 @@ def get_all_safetensors(model_path: str):
170175
safe_model_path = os.path.join(model_path, "model.safetensors")
171176
if os.path.exists(safe_model_path):
172177
safetensor_list = [safe_model_path]
173-
with safe_open(safe_model_path, framework="np", device="cpu") as f:
178+
with safe_open(safe_model_path, framework="pp", device="cpu") as f:
174179
key_name_list = f.keys()
175180
return key_name_list, safetensor_list
176181
else:
@@ -187,6 +192,221 @@ def get_all_safetensors(model_path: str):
187192
return key_name_list, safetensor_list
188193

189194

195+
196+
def _add_variant(weights_name: str, variant=None) -> str:
    """Insert ``variant`` before the last dot-separated part of ``weights_name``.

    Args:
        weights_name: File name such as ``"model.safetensors"``.
        variant: Optional tag to insert. ``None`` or an empty value leaves the
            name untouched.

    Returns:
        The name with ``variant`` spliced in front of its final component,
        e.g. ``"model.fp16.safetensors"``, or the original name when no
        variant is given.
    """
    if variant is None or len(variant) == 0:
        return weights_name

    parts = weights_name.split(".")
    # The variant goes right before the final component (the extension when
    # the name contains a dot, otherwise the whole name).
    parts.insert(len(parts) - 1, variant)
    return ".".join(parts)
203+
204+
@contextlib.contextmanager
def device_guard(device="cpu", dev_id=0):
    """Context manager that switches the paddle device and restores it on exit.

    Args:
        device: ``"cpu"`` selects the CPU; ``"gpu"``, ``"xpu"`` or ``"npu"``
            select ``"<device>:<dev_id>"``. Any other value leaves the current
            device unchanged for the duration of the block.
        dev_id: Device index used for the non-CPU devices.

    The device reported by ``paddle.device.get_device()`` on entry is set
    again when the block exits, including when it exits with an exception.
    """
    previous_device = paddle.device.get_device()

    if device == "cpu":
        target = device
    elif device in ("gpu", "xpu", "npu"):
        target = "{}:{}".format(device, dev_id)
    else:
        # Unrecognised device names are ignored rather than rejected.
        target = None

    if target is not None:
        paddle.set_device(target)
    try:
        yield
    finally:
        paddle.set_device(previous_device)
215+
216+
def _split_keys_evenly(keys: list, n: int) -> list:
    """Split ``keys`` into ``n`` consecutive chunks of near-equal length.

    Args:
        keys: Sequence to partition; order is preserved.
        n: Number of chunks to produce.

    Returns:
        A list of ``n`` slices of ``keys``. The first ``len(keys) % n`` chunks
        hold one extra element; trailing chunks are empty when ``n`` exceeds
        ``len(keys)``.

    Raises:
        ZeroDivisionError: If ``n`` is zero.
    """
    quotient, remainder = divmod(len(keys), n)

    chunks = []
    start = 0
    for position in range(n):
        # The leading `remainder` chunks absorb the leftover elements.
        stop = start + quotient + (1 if position < remainder else 0)
        chunks.append(keys[start:stop])
        start = stop
    return chunks
231+
232+
def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
    """Load every weight of the checkpoint in ``folder`` into one state dict.

    Single-file checkpoints are tried first, in this order:
    ``model_state.pdparams``, ``lora_model_state.pdparams`` (both read with
    ``paddle.load``) and ``model.safetensors``. If none exists, an index JSON
    is looked up (safetensors index, master-weights index, pdparams index,
    PEFT index, in that priority) and every shard named in its
    ``"weight_map"`` is read and merged.

    Args:
        folder: Directory containing the checkpoint files.
        variant: Optional variant tag inserted into each file name via
            ``_add_variant``.
        return_numpy: Forwarded to ``paddle.load`` for ``.pdparams`` files.
            For safetensors files, when false, any loaded value that is a
            ``numpy.ndarray`` is wrapped into a paddle tensor.

    Returns:
        The object returned by ``paddle.load`` for ``.pdparams`` files,
        otherwise a dict mapping tensor names to the loaded tensors.

    Raises:
        ValueError: If neither a single-file checkpoint nor any index file is
            found in ``folder``.
    """

    def _wrap_numpy_values(tensors):
        # Wrap NumPy arrays as paddle tensors in place unless the caller asked
        # for NumPy output; non-ndarray values are left untouched.
        if not return_numpy:
            for name in list(tensors.keys()):
                if isinstance(tensors[name], np.ndarray):
                    tensors[name] = paddle.Tensor.__call__(tensors.pop(name), zero_copy=True)
        return tensors

    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))

    if os.path.isfile(pdparams_file):
        return paddle.load(pdparams_file, return_numpy=return_numpy)
    if os.path.isfile(lora_pdparams_file):
        return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
    if os.path.isfile(safetensors_file):
        state_dict = {}
        with safe_open(safetensors_file, framework="pp") as f:
            for key in f.keys():
                # get_tensor needs the tensor name; calling it without one
                # fails before anything is loaded.
                state_dict[key] = f.get_tensor(key)
        return _wrap_numpy_values(state_dict)

    PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
    SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
    SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"
    SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json"

    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

    # Candidates are listed in priority order; the first existing one wins.
    load_index = None
    for candidate in (safe_index_file, safe_master_file, index_file, safe_peft_file):
        if os.path.isfile(candidate):
            load_index = candidate
            break
    if load_index is None:
        raise ValueError(
            f"Could not find {index_file} or {safe_index_file} or {safe_master_file} or {safe_peft_file}"
        )

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Several weights usually live in the same shard; open each shard once.
    shard_files = list(set(index["weight_map"].values()))
    ret = {}
    # NOTE(review): shards are always opened with safe_open, even when the
    # selected index is the pdparams one - confirm pdparams shards are not
    # expected to reach this loop.
    for shard_file in tqdm(shard_files):
        with safe_open(os.path.join(folder, shard_file), framework="pp") as f:
            for key in f.keys():
                ret[key] = f.get_tensor(key)
    return _wrap_numpy_values(ret)
292+
293+
def _load_part_state_dict(
    keys,
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping,
    fliter_dict_keys,
    return_numpy=False,
):
    """Load the tensors named in ``keys`` from one safetensors file.

    Args:
        keys: Names of the tensors to read from ``checkpoint_file``.
        checkpoint_file: Path of the safetensors file to open.
        tensor_parallel_split_mapping: Mapping from tensor name to a callable;
            when a name is present, the callable is applied to the loaded
            tensor and its result is stored instead.
        fliter_dict_keys: Accepted for signature compatibility; this function
            does not read it.
        return_numpy: When false, each weight is wrapped with
            ``paddle.Tensor.__call__(..., zero_copy=True)`` under
            ``device_guard()`` and then copied to
            ``paddle.framework._current_expected_place()``.

    Returns:
        Dict mapping each requested name to its (possibly split and
        converted) weight.
    """
    loaded = {}
    with safe_open(checkpoint_file, framework="pp") as reader:
        for name in keys:
            raw = reader.get_tensor(name)
            weight = tensor_parallel_split_mapping[name](raw) if name in tensor_parallel_split_mapping else raw

            if return_numpy:
                loaded[name] = weight
                continue

            # Wrap while the CPU device is active, then move the result to
            # whatever place paddle currently expects.
            with device_guard():
                weight = paddle.Tensor.__call__(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            loaded[name] = weight
    return loaded
314+
315+
def load_tp_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    return_numpy=False,
):
    """Load one safetensors file, applying tensor-parallel split actions.

    The tensor names are read once and handed to ``_load_part_state_dict``,
    either in a single call or, when the ``LOAD_STATE_DICT_THREAD_NUM``
    environment variable is greater than 1, split across that many worker
    threads.

    Args:
        checkpoint_file: Path whose name ends in ``.safetensors`` or in
            ``.safetensors_shard_`` followed by four digits.
        tensor_parallel_split_mapping: Optional mapping from tensor name to a
            callable applied to that tensor; defaults to an empty mapping.
        fliter_dict_keys: Forwarded unchanged to ``_load_part_state_dict``.
        device: When ``return_numpy`` is false, ``"cpu"`` wraps every value
            with ``paddle.Tensor.__call__(..., zero_copy=True)`` under
            ``device_guard()`` and ``"pin_memory"`` converts every value with
            ``paddle.to_tensor(..., place=paddle.CUDAPinnedPlace())``. Other
            values leave the loaded weights as they are.
        return_numpy: Forwarded to ``_load_part_state_dict``; when true the
            ``device`` post-processing is skipped.

    Returns:
        Dict mapping tensor names to loaded weights.

    Raises:
        ValueError: If ``checkpoint_file`` is not a safetensors file or shard.
    """
    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}

    # Accept os.PathLike as the annotation promises; str has no-op fspath.
    checkpoint_path = os.fspath(checkpoint_file)
    is_safetensors = checkpoint_path.endswith(".safetensors") or re.search(
        r"\.safetensors_shard_\d{4}$", checkpoint_path
    )
    if not is_safetensors:
        # Previously this fell through with no state dict bound and failed
        # with an unrelated name error further down.
        raise ValueError(
            f"Unsupported checkpoint file {checkpoint_path!r}: expected a safetensors file or shard"
        )

    thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))

    # Read the key list once; the workers reopen the file themselves.
    with safe_open(checkpoint_file, framework="pp") as f:
        all_keys = list(f.keys())

    state_dict = {}
    if thread_num <= 1:
        state_dict = _load_part_state_dict(
            all_keys,
            checkpoint_file,
            tensor_parallel_split_mapping,
            fliter_dict_keys,
            return_numpy,
        )
    else:
        # Load state dict in multi-thread to speed up loading
        keys_groups = _split_keys_evenly(all_keys, thread_num)
        with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
            futures = [
                executor.submit(
                    _load_part_state_dict,
                    keys,
                    checkpoint_file,
                    tensor_parallel_split_mapping,
                    fliter_dict_keys,
                    return_numpy,
                )
                for keys in keys_groups
            ]
            for future in concurrent.futures.as_completed(futures):
                # result() re-raises any exception from the worker thread.
                state_dict.update(future.result())

    if not return_numpy:
        if device == "cpu":
            with device_guard():
                for k in list(state_dict.keys()):
                    state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
        elif device == "pin_memory":
            for k in list(state_dict.keys()):
                state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())

    return state_dict
368+
369+
370+
def load_tp_checkpoint(
    folder: str,
    cls: PretrainedModel,
    config: ModelConfig,
    return_numpy: bool = True,
):
    """Load the weights in ``folder`` for the current tensor-parallel rank.

    When ``config.tensor_parallel_degree`` is 1 or -1 the whole checkpoint is
    loaded with ``load_sharded_checkpoint_as_one``. Otherwise the first
    available source is used, in this order:

    1. a per-rank ``model_state.tpNN.pdparams`` file, loaded as is;
    2. ``model_state.pdparams``, converted by ``cls.convert_tensor_parallel``;
    3. ``model.safetensors``, loaded with the actions from
       ``cls.get_tensor_parallel_convert_actions``;
    4. sharded safetensors files resolved by ``cls._resolve_model_file_path``.

    Args:
        folder: Directory containing the checkpoint.
        cls: Model class providing the tensor-parallel conversion helpers.
        config: Model config; ``tensor_parallel_degree`` and
            ``tensor_parallel_rank`` are read here.
        return_numpy: Forwarded to the underlying loaders.

    Returns:
        The state dict for this rank.
    """
    if config.tensor_parallel_degree in (1, -1):
        return load_sharded_checkpoint_as_one(folder, return_numpy=return_numpy)

    # Two-digit zero-padded rank suffix. Identical to the former "tp0{rank}"
    # for ranks 0-9, and yields "tp10" rather than "tp010" from rank 10 on.
    rank_model_path = os.path.join(folder, f"model_state.tp{config.tensor_parallel_rank:0>2d}.pdparams")
    model_path = os.path.join(folder, "model_state.pdparams")
    safe_model_path = os.path.join(folder, "model.safetensors")

    if os.path.exists(rank_model_path):
        return paddle.load(rank_model_path, return_numpy=return_numpy)

    if os.path.exists(model_path):
        state_dict = cls.convert_tensor_parallel(model_path, config)
    elif os.path.exists(safe_model_path):
        with safe_open(safe_model_path, framework="pp", device="cpu") as f:
            loaded_keys = f.keys()
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
        state_dict = load_tp_state_dict(safe_model_path, tp_actions, return_numpy=return_numpy)
    else:  # shard files safetensors
        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
            pretrained_model_name_or_path=folder,
            use_safetensors=True,
        )
        if len(resolved_sharded_files) > 1:
            # Only show a progress bar when there is more than one shard.
            resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
        state_dict = {}
        for shard_file in resolved_sharded_files:
            shard_state_dict = load_tp_state_dict(  # todo: for this function
                shard_file,
                tp_actions,
                loaded_state_dict_keys,
                return_numpy=return_numpy,
            )
            state_dict.update(shard_state_dict)
    return state_dict
409+
190410
def load_tp_checkpoint_v1(
191411
model_path: str,
192412
cls: PretrainedModel,

0 commit comments

Comments
 (0)