# limitations under the License.
"""

+ import concurrent
+ import concurrent.futures
+ import contextlib
import json
import os
+ import re
+ from typing import Union

+ import numpy as np
import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
- from paddleformers.transformers.model_utils import load_tp_checkpoint
from safetensors import safe_open
from tqdm import tqdm
@@ -78,7 +83,7 @@ def load_ep_checkpoint(model_path: str,
                                desc="Loading safetensor files",
                                unit="file"):
        with safe_open(os.path.join(model_path, safetensor_path),
-                      framework="np",
+                      framework="pp",
                       device="cpu") as f:
            # Check if this file contains keys from filtered_map
            for k in filtered_map:
@@ -92,15 +97,15 @@ def load_ep_checkpoint(model_path: str,
    return state_dict


- def safetensors_weights_iterator(safe_tensor_list: list[str], ):
+ def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    safetensors_weights_iterator
    """
    for st_file in tqdm(
            safe_tensor_list,
            desc="Loading safetensors checkpoint shards",
    ):
-       with safe_open(st_file, framework="np") as f:
+       with safe_open(st_file, framework="pp") as f:
            for name in f.keys():
                param = f.get_tensor(name)
                yield name, param
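A minimal usage sketch for the iterator above (not part of this change; the shard path is a placeholder): it streams tensors one at a time instead of materializing the whole checkpoint.

```python
# Placeholder path; any list of .safetensors shards works.
for name, tensor in safetensors_weights_iterator(["model-00001-of-00002.safetensors"]):
    print(name, tuple(tensor.shape))
```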
@@ -170,7 +175,7 @@ def get_all_safetensors(model_path: str):
    safe_model_path = os.path.join(model_path, "model.safetensors")
    if os.path.exists(safe_model_path):
        safetensor_list = [safe_model_path]
-       with safe_open(safe_model_path, framework="np", device="cpu") as f:
+       with safe_open(safe_model_path, framework="pp", device="cpu") as f:
            key_name_list = f.keys()
        return key_name_list, safetensor_list
    else:
@@ -187,6 +192,221 @@ def get_all_safetensors(model_path: str):
        return key_name_list, safetensor_list


+
+ def _add_variant(weights_name: str, variant=None) -> str:
+     if variant is not None and len(variant) > 0:
+         splits = weights_name.split(".")
+         splits = splits[:-1] + [variant] + splits[-1:]
+         weights_name = ".".join(splits)
+
+     return weights_name
+
+ @contextlib.contextmanager
+ def device_guard(device="cpu", dev_id=0):
+     origin_device = paddle.device.get_device()
+     if device == "cpu":
+         paddle.set_device(device)
+     elif device in ["gpu", "xpu", "npu"]:
+         paddle.set_device("{}:{}".format(device, dev_id))
+     try:
+         yield
+     finally:
+         paddle.set_device(origin_device)
+
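For context, a small sketch of how the `device_guard` helper can be used (illustrative only, assuming Paddle is installed): tensors built inside the block land on the requested device, and the previous device is restored on exit.

```python
import paddle

with device_guard("cpu"):
    cpu_weight = paddle.zeros([4, 4])   # created on CPU regardless of the default device
print(paddle.device.get_device())       # original device is restored here
```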
+ def _split_keys_evenly(keys: list, n: int) -> list:
+
+     total_len = len(keys)
+     base_size = total_len // n
+     extra = total_len % n
+
+     result = []
+     index = 0
+     for _ in range(n):
+         part_size = base_size + 1 if extra > 0 else base_size
+         extra -= 1
+         result.append(keys[index: index + part_size])
+         index += part_size
+
+     return result
+
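A quick illustration of `_split_keys_evenly` (illustrative only): seven keys split across three workers come back as chunks of sizes 3, 2 and 2.

```python
keys = [f"layer.{i}.weight" for i in range(7)]
print(_split_keys_evenly(keys, 3))
# [['layer.0.weight', 'layer.1.weight', 'layer.2.weight'],
#  ['layer.3.weight', 'layer.4.weight'],
#  ['layer.5.weight', 'layer.6.weight']]
```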
+ def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
+     pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
+     lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
+     safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
+     if os.path.isfile(pdparams_file):
+         return paddle.load(pdparams_file, return_numpy=return_numpy)
+     if os.path.isfile(lora_pdparams_file):
+         return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
+     if os.path.isfile(safetensors_file):
+         state_dict = {}
+         with safe_open(safetensors_file, framework="pp") as f:
+             for key in f.keys():
+                 state_dict[key] = f.get_tensor(key)
+         if not return_numpy:
+             for key in list(state_dict.keys()):
+                 if isinstance(state_dict[key], np.ndarray):
+                     state_dict[key] = paddle.Tensor.__call__(state_dict.pop(key), zero_copy=True)
+         return state_dict
+
+     PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
+     SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+     SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"
+     SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json"
+
+     index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
+     safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+     safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
+     safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
+
+     index_present = os.path.isfile(index_file)
+     safe_index_present = os.path.isfile(safe_index_file)
+     safe_master_present = os.path.isfile(safe_master_file)
+     safe_peft_present = os.path.isfile(safe_peft_file)
+
+     load_index = None
+     if safe_index_present:
+         load_index = safe_index_file
+     elif safe_master_present:
+         load_index = safe_master_file
+     elif index_present:
+         load_index = index_file
+     elif safe_peft_present:
+         load_index = safe_peft_file
+     else:
+         raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
+
+     with open(load_index, "r", encoding="utf-8") as f:
+         index = json.load(f)
+
+     shard_files = list(set(index["weight_map"].values()))
+     ret = {}
+     for shard_file in tqdm(shard_files):
+         with safe_open(os.path.join(folder, shard_file), framework="pp", device="cpu") as f:
+             for key in f.keys():
+                 ret[key] = f.get_tensor(key)
+     if not return_numpy:
+         for key in list(ret.keys()):
+             if isinstance(ret[key], np.ndarray):
+                 ret[key] = paddle.Tensor.__call__(ret.pop(key), zero_copy=True)
+     return ret
+
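The sharded branch above walks a `*.safetensors.index.json` file whose `weight_map` maps parameter names to the shard that stores them; a toy sketch of that structure (names made up):

```python
index = {
    "weight_map": {
        "embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }
}
shard_files = sorted(set(index["weight_map"].values()))
print(shard_files)  # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```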
+ def _load_part_state_dict(
+     keys,
+     checkpoint_file: Union[str, os.PathLike],
+     tensor_parallel_split_mapping,
+     fliter_dict_keys,
+     return_numpy=False,
+ ):
+     part_state_dict = {}
+     with safe_open(checkpoint_file, framework="pp") as f:
+         for key in keys:
+             py_safe_slice_ = f.get_tensor(key)
+             if key in tensor_parallel_split_mapping:
+                 weight = tensor_parallel_split_mapping[key](py_safe_slice_)
+             else:
+                 weight = py_safe_slice_
+             if not return_numpy:
+                 with device_guard():
+                     weight = paddle.Tensor.__call__(weight, zero_copy=True)
+                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
+             part_state_dict[key] = weight
+     return part_state_dict
+
+ def load_tp_state_dict(checkpoint_file: Union[str, os.PathLike],
+                        tensor_parallel_split_mapping=None,
+                        fliter_dict_keys=None,
+                        device="cpu",
+                        return_numpy=False):
+
+     if tensor_parallel_split_mapping is None:
+         tensor_parallel_split_mapping = {}
+
+     if (
+         checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file)
+     ):
+         thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
+         state_dict = {}
+         if thread_num <= 1:
+             with safe_open(checkpoint_file, framework="pp") as f:
+                 state_dict = _load_part_state_dict(
+                     list(f.keys()),
+                     checkpoint_file,
+                     tensor_parallel_split_mapping,
+                     fliter_dict_keys,
+                     return_numpy,
+                 )
+         else:
+             # Load state dict in multi-thread to speed up loading
+             with safe_open(checkpoint_file, framework="pp") as f:
+                 keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+             with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                 future_to_key = {
+                     executor.submit(
+                         _load_part_state_dict,
+                         keys,
+                         checkpoint_file,
+                         tensor_parallel_split_mapping,
+                         fliter_dict_keys,
+                         return_numpy,
+                     ): keys
+                     for keys in keys_groups
+                 }
+                 for future in concurrent.futures.as_completed(future_to_key):
+                     res_state_dict = future.result()
+                     state_dict.update(res_state_dict)
+
+         if not return_numpy:
+             if device == "cpu":
+                 with device_guard():
+                     for k in list(state_dict.keys()):
+                         state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+             elif device == "pin_memory":
+                 for k in list(state_dict.keys()):
+                     state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())
+
+         return state_dict
+
+
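A hedged usage sketch for `load_tp_state_dict`; the checkpoint path is a placeholder, and setting `LOAD_STATE_DICT_THREAD_NUM` above 1 switches to the `ThreadPoolExecutor` branch shown above.

```python
import os

os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "4"   # enable multi-threaded loading
state_dict = load_tp_state_dict(
    "model.safetensors",                  # placeholder shard on disk
    tensor_parallel_split_mapping=None,   # no TP split actions: tensors are kept whole
    return_numpy=True,
)
```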
+ def load_tp_checkpoint(
+     folder: str,
+     cls: PretrainedModel,
+     config: ModelConfig,
+     return_numpy: bool = True,
+ ):
+     if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
+         return load_sharded_checkpoint_as_one(folder, return_numpy=return_numpy)
+     rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
+     model_path = os.path.join(folder, "model_state.pdparams")
+     safe_model_path = os.path.join(folder, "model.safetensors")
+     if os.path.exists(rank_model_path):
+         return paddle.load(rank_model_path, return_numpy=return_numpy)
+     elif os.path.exists(model_path):
+         state_dict = cls.convert_tensor_parallel(model_path, config)
+     elif os.path.exists(safe_model_path):
+         with safe_open(safe_model_path, framework="pp", device="cpu") as f:
+             loaded_keys = f.keys()
+         tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
+         state_dict = load_tp_state_dict(safe_model_path, tp_actions, return_numpy=return_numpy)
+     else:  # shard files safetensors
+         resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+             pretrained_model_name_or_path=folder,
+             use_safetensors=True,
+         )
+         if len(resolved_sharded_files) > 1:
+             resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
+         loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+         tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
+         state_dict = {}
+         for shard_file in resolved_sharded_files:
+             shard_state_dict = load_tp_state_dict(  # todo: for this function
+                 shard_file,
+                 tp_actions,
+                 loaded_state_dict_keys,
+                 return_numpy=return_numpy,
+             )
+             state_dict.update(shard_state_dict)
+     return state_dict
+
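Illustrative only: the per-rank file name that `load_tp_checkpoint` probes first when tensor parallelism is enabled (the folder is a placeholder).

```python
import os

tensor_parallel_rank = 3
print(os.path.join("/ckpt", f"model_state.tp0{tensor_parallel_rank}.pdparams"))
# -> /ckpt/model_state.tp03.pdparams
```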
def load_tp_checkpoint_v1(
    model_path: str,
    cls: PretrainedModel,