21
21
import os
22
22
23
23
import torch
24
- from accelerate import init_empty_weights
25
24
from accelerate .utils import set_module_tensor_to_device
26
- from transformers import AutoConfig , AutoModelForCausalLM
27
- from transformers .models .auto .auto_factory import _BaseAutoModelClass
25
+ from safetensors import safe_open
28
26
29
27
from neural_compressor .common import options
30
28
from neural_compressor .torch .algorithms .weight_only .modules import INCWeightOnlyLinear
29
+ from neural_compressor .torch .utils .utility import dowload_hf_model , load_empty_model
31
30
32
31
from .load import load
33
32
@@ -94,59 +93,6 @@ def get_named_children(model, pre=[]):
94
93
return module_list
95
94
96
95
97
def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):  # pragma: no cover
    """Download a Hugging Face model from the hub, reusing the local cache when possible.

    Args:
        repo_id (str): hub repository id, e.g. "facebook/opt-125m".
        cache_dir (str, optional): hub cache root. Defaults to HUGGINGFACE_HUB_CACHE.
        repo_type (str, optional): repository type. Defaults to "model".
        revision (str, optional): branch, tag, or commit hash. Defaults to DEFAULT_REVISION.

    Returns:
        str: local directory containing the model snapshot.
    """
    from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
    from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name

    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if repo_type is None:
        repo_type = "model"
    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
    # Resolve the requested revision to a concrete commit hash so the cached
    # snapshot directory can be located.
    commit_hash = None
    if REGEX_COMMIT_HASH.match(revision):
        commit_hash = revision
    else:
        ref_path = os.path.join(storage_folder, "refs", revision)
        if os.path.exists(ref_path):
            with open(ref_path) as f:
                commit_hash = f.read()
    if commit_hash:
        pointer_path = os.path.join(storage_folder, "snapshots", commit_hash)
        if os.path.isdir(pointer_path):
            return pointer_path
    # Cache miss: fetch a snapshot from the hub. NOTE(fix): the previous version
    # returned None when the commit hash was resolvable but the snapshot directory
    # was absent; always falling back to snapshot_download closes that hole.
    from huggingface_hub import snapshot_download

    file_path = snapshot_download(repo_id)
    return file_path
127
-
128
-
129
def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):  # pragma: no cover
    """Build a weight-less model skeleton on the meta device.

    Args:
        pretrained_model_name_or_path (str): local directory or HF hub repo id.
        cls (type, optional): model class to instantiate. Defaults to AutoModelForCausalLM.
        **kwargs: forwarded to the config loader.

    Returns:
        torch.nn.Module: empty model in eval mode; its ``path`` attribute records
        the original ``pretrained_model_name_or_path``.
    """
    local = os.path.isdir(pretrained_model_name_or_path)
    # Remote repo ids need a local snapshot before the config can be read.
    path = pretrained_model_name_or_path if local else dowload_hf_model(pretrained_model_name_or_path)
    is_auto = cls.__base__ == _BaseAutoModelClass
    if is_auto:
        # Auto classes cannot be called directly; load a generic config for from_config.
        config = AutoConfig.from_pretrained(path, **kwargs)
    else:  # pragma: no cover
        # Concrete model class: use its own config class and constructor.
        config = cls.config_class.from_pretrained(path, **kwargs)
    with init_empty_weights():
        model = cls.from_config(config) if is_auto else cls(config)
    model.tie_weights()
    model.eval()
    model.path = pretrained_model_name_or_path
    return model
148
-
149
-
150
96
def get_super_module_by_name (model , module_name ):
151
97
"""Get the father module with given name of child module."""
152
98
name_list = module_name .split ("." )
@@ -211,6 +157,27 @@ def load_tensor(path, tensor_name=None, prefix=None):
211
157
return state_dict
212
158
213
159
160
def load_tensor_from_safetensors(path, tensor_name=None, device="cpu"):
    """Read a single named tensor out of a safetensors file.

    Args:
        path (str): path to the ``.safetensors`` file.
        tensor_name (str, optional): name of the tensor to fetch.
        device (str, optional): device to materialize the tensor on. Defaults to "cpu".

    Returns:
        torch.Tensor: the requested tensor.
    """
    with safe_open(path, framework="pt", device=device) as handle:
        return handle.get_tensor(tensor_name)
165
+
166
+
167
def load_tensor_from_safetensors_shard(
    pretrained_model_name_or_path, tensor_name, prefix=None, device="cpu"
):  # pragma: no cover
    """Load a named tensor from a sharded safetensors checkpoint.

    Args:
        pretrained_model_name_or_path (str): local directory or HF hub repo id.
        tensor_name (str): fully qualified tensor name.
        prefix (str, optional): base-model prefix that may need stripping from tensor_name.
        device (str, optional): target device. Defaults to "cpu".

    Returns:
        torch.Tensor: the requested tensor.

    Raises:
        ValueError: if tensor_name (with or without prefix) is not listed in the index.
    """
    path = _get_path(pretrained_model_name_or_path)
    # Use a context manager so the index file handle is closed deterministically
    # (the previous json.load(open(...)) leaked it).
    with open(os.path.join(path, "model.safetensors.index.json"), "r") as f:
        idx_dict = json.load(f)["weight_map"]
    if tensor_name not in idx_dict:
        stripped = tensor_name.replace(f"{prefix}.", "")
        if stripped in idx_dict:
            tensor_name = stripped
        else:
            # Raise instead of `assert False` so the check survives `python -O`.
            raise ValueError("{} not in the index.json".format(tensor_name))
    return load_tensor_from_safetensors(os.path.join(path, idx_dict[tensor_name]), tensor_name, device)
179
+
180
+
214
181
def _get_path (pretrained_model_name_or_path ):
215
182
is_local = os .path .isdir (pretrained_model_name_or_path )
216
183
if is_local : # pragma: no cover
@@ -223,13 +190,14 @@ def _get_path(pretrained_model_name_or_path):
223
190
get_path = _get_path
224
191
225
192
226
- def load_value (model , param_name , path ):
193
+ def load_value (model , param_name , path , device = "cpu" ):
227
194
"""Load the module value.
228
195
229
196
Args:
230
197
model (torch.nn.module): torch model.
231
198
param_name (str): module name.
232
199
path (str): path to load state_dict per layer.
200
+ device (str, optional): module device. Defaults to "cpu".
233
201
234
202
Returns:
235
203
tensor: the module value.
@@ -241,7 +209,13 @@ def load_value(model, param_name, path):
241
209
if module == input_embeddings :
242
210
param_name = name + "." + param_name .split ("." )[- 1 ]
243
211
prefix = model .base_model_prefix
244
- if "pytorch_model.bin.index.json" in os .listdir (path ):
212
+ files = os .listdir (path )
213
+ safetensors_files = [filename for filename in files if filename .endswith (".safetensors" )]
214
+ if len (safetensors_files ) == 1 :
215
+ value = load_tensor_from_safetensors (os .path .join (path , "model.safetensors" ), param_name , device = device )
216
+ elif len (safetensors_files ) >= 2 :
217
+ value = load_tensor_from_safetensors_shard (path , param_name , device = device )
218
+ elif "pytorch_model.bin.index.json" in files :
245
219
value = load_tensor_from_shard (path , param_name , prefix )
246
220
else :
247
221
value = load_tensor (os .path .join (path , "pytorch_model.bin" ), param_name , prefix )
@@ -260,7 +234,7 @@ def load_module(model, module_name, path, device="cpu"):
260
234
module = get_module (model , module_name )
261
235
for n , p in module .named_parameters ():
262
236
param_name = module_name + "." + n
263
- value = load_value (model , param_name , path )
237
+ value = load_value (model , param_name , path , device )
264
238
set_module_tensor_to_device (model , param_name , device , value )
265
239
266
240
0 commit comments