
Commit 93d7746

Support safetensors loading for layerwise (#2047)

Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 30e803d

File tree: 4 files changed (+94 / -67 lines)

- neural_compressor/torch/algorithms/layer_wise/utils.py
- neural_compressor/torch/utils/utility.py
- neural_compressor/transformers/models/modeling_auto.py
- test/3x/torch/quantization/weight_only/test_transfomers.py


neural_compressor/torch/algorithms/layer_wise/utils.py

Lines changed: 33 additions & 59 deletions
```diff
@@ -21,13 +21,12 @@
 import os
 
 import torch
-from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
-from transformers import AutoConfig, AutoModelForCausalLM
-from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from safetensors import safe_open
 
 from neural_compressor.common import options
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
+from neural_compressor.torch.utils.utility import dowload_hf_model, load_empty_model
 
 from .load import load
 
@@ -94,59 +93,6 @@ def get_named_children(model, pre=[]):
     return module_list
 
 
-def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):  # pragma: no cover
-    """Download hugging face model from hf hub."""
-    from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
-    from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
-    from huggingface_hub.utils import EntryNotFoundError
-
-    if cache_dir is None:
-        cache_dir = HUGGINGFACE_HUB_CACHE
-    if revision is None:
-        revision = DEFAULT_REVISION
-    if repo_type is None:
-        repo_type = "model"
-    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
-    commit_hash = None
-    if REGEX_COMMIT_HASH.match(revision):
-        commit_hash = revision
-    else:
-        ref_path = os.path.join(storage_folder, "refs", revision)
-        if os.path.exists(ref_path):
-            with open(ref_path) as f:
-                commit_hash = f.read()
-    if storage_folder and commit_hash:
-        pointer_path = os.path.join(storage_folder, "snapshots", commit_hash)
-        if os.path.isdir(pointer_path):
-            return pointer_path
-    else:  # pragma: no cover
-        from huggingface_hub import snapshot_download
-
-        file_path = snapshot_download(repo_id)
-        return file_path
-
-
-def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):  # pragma: no cover
-    """Load a empty model."""
-    is_local = os.path.isdir(pretrained_model_name_or_path)
-    if is_local:  # pragma: no cover
-        path = pretrained_model_name_or_path
-    else:
-        path = dowload_hf_model(pretrained_model_name_or_path)
-    if cls.__base__ == _BaseAutoModelClass:
-        config = AutoConfig.from_pretrained(path, **kwargs)
-        with init_empty_weights():
-            model = cls.from_config(config)
-    else:  # pragma: no cover
-        config = cls.config_class.from_pretrained(path, **kwargs)
-        with init_empty_weights():
-            model = cls(config)
-    model.tie_weights()
-    model.eval()
-    model.path = pretrained_model_name_or_path
-    return model
-
-
 def get_super_module_by_name(model, module_name):
     """Get the father module with given name of child module."""
     name_list = module_name.split(".")
```
```diff
@@ -211,6 +157,27 @@ def load_tensor(path, tensor_name=None, prefix=None):
     return state_dict
 
 
+def load_tensor_from_safetensors(path, tensor_name=None, device="cpu"):
+    """Load a tensor from safetensors file with given tensor name."""
+    with safe_open(path, framework="pt", device=device) as f:
+        value = f.get_tensor(tensor_name)
+    return value
+
+
+def load_tensor_from_safetensors_shard(
+    pretrained_model_name_or_path, tensor_name, prefix=None, device="cpu"
+):  # pragma: no cover
+    """Load tensor from shard."""
+    path = _get_path(pretrained_model_name_or_path)
+    idx_dict = json.load(open(os.path.join(path, "model.safetensors.index.json"), "r"))["weight_map"]
+    if tensor_name not in idx_dict.keys():
+        if tensor_name.replace(f"{prefix}.", "") in idx_dict.keys():
+            tensor_name = tensor_name.replace(f"{prefix}.", "")
+        else:
+            assert False, "{} not in the index.json".format(tensor_name)
+    return load_tensor_from_safetensors(os.path.join(path, idx_dict[tensor_name]), tensor_name, device)
+
+
 def _get_path(pretrained_model_name_or_path):
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:  # pragma: no cover
```
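The two new helpers above are the core of the change: `safe_open` memory-maps a safetensors file and `get_tensor` materializes only the requested tensor, which is what keeps layer-wise loading cheap. A minimal standalone sketch of the same pattern (the file path and tensor name below are illustrative, not from the commit):

```python
# Minimal sketch of per-tensor safetensors loading; path and tensor
# name are illustrative examples.
from safetensors import safe_open

def read_one_tensor(path, tensor_name, device="cpu"):
    # safe_open memory-maps the file; get_tensor materializes only the
    # named tensor, so peak memory stays near a single tensor's size.
    with safe_open(path, framework="pt", device=device) as f:
        return f.get_tensor(tensor_name)

weight = read_one_tensor("model.safetensors", "model.embed_tokens.weight")
print(weight.shape)
```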
```diff
@@ -223,13 +190,14 @@ def _get_path(pretrained_model_name_or_path):
 get_path = _get_path
 
 
-def load_value(model, param_name, path):
+def load_value(model, param_name, path, device="cpu"):
     """Load the module value.
 
     Args:
         model (torch.nn.module): torch model.
         param_name (str): module name.
         path (str): path to load state_dict per layer.
+        device (str, optional): module device. Defaults to "cpu".
 
     Returns:
         tensor: the module value.
@@ -241,7 +209,13 @@ def load_value(model, param_name, path):
             if module == input_embeddings:
                 param_name = name + "." + param_name.split(".")[-1]
     prefix = model.base_model_prefix
-    if "pytorch_model.bin.index.json" in os.listdir(path):
+    files = os.listdir(path)
+    safetensors_files = [filename for filename in files if filename.endswith(".safetensors")]
+    if len(safetensors_files) == 1:
+        value = load_tensor_from_safetensors(os.path.join(path, "model.safetensors"), param_name, device=device)
+    elif len(safetensors_files) >= 2:
+        value = load_tensor_from_safetensors_shard(path, param_name, device=device)
+    elif "pytorch_model.bin.index.json" in files:
         value = load_tensor_from_shard(path, param_name, prefix)
     else:
         value = load_tensor(os.path.join(path, "pytorch_model.bin"), param_name, prefix)
```
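The dispatch order added to `load_value` prefers safetensors: exactly one `.safetensors` file is read directly as `model.safetensors`, two or more are resolved through `model.safetensors.index.json`, and otherwise the existing `pytorch_model.bin` paths are used. A standalone sketch of that decision tree (the directory layout is assumed; helper below is hypothetical):

```python
import os

def checkpoint_format(path):
    # Mirrors the branch order added to load_value: safetensors first,
    # then the legacy .bin layouts.
    files = os.listdir(path)
    safetensors_files = [f for f in files if f.endswith(".safetensors")]
    if len(safetensors_files) == 1:
        return "single safetensors file"
    if len(safetensors_files) >= 2:
        return "sharded safetensors (model.safetensors.index.json)"
    if "pytorch_model.bin.index.json" in files:
        return "sharded pytorch_model.bin"
    return "single pytorch_model.bin"
```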
```diff
@@ -260,7 +234,7 @@ def load_module(model, module_name, path, device="cpu"):
     module = get_module(model, module_name)
     for n, p in module.named_parameters():
         param_name = module_name + "." + n
-        value = load_value(model, param_name, path)
+        value = load_value(model, param_name, path, device)
         set_module_tensor_to_device(model, param_name, device, value)
```
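Threading `device` from `load_module` into `load_value` lets each parameter be materialized on the target device before `set_module_tensor_to_device` attaches it. A hypothetical end-to-end call, assuming an OPT-style checkpoint (the model id and module name are illustrative):

```python
# Hypothetical usage; "facebook/opt-125m" and the module name are examples.
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module
from neural_compressor.torch.utils.utility import load_empty_model

model = load_empty_model("facebook/opt-125m")  # weights stay on the meta device
path = get_path("facebook/opt-125m")           # local snapshot directory
# Materialize a single layer's parameters from the checkpoint on CPU:
load_module(model, "model.decoder.layers.0.self_attn.k_proj", path, device="cpu")
```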

neural_compressor/torch/utils/utility.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -331,13 +331,24 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
                 commit_hash = f.read()
     if storage_folder and commit_hash:
         pointer_path = os.path.join(storage_folder, "snapshots", commit_hash)
-        if os.path.isdir(pointer_path):
+        if os.path.isdir(pointer_path) and any(
+            file.endswith(".bin") or file.endswith(".safetensors") for file in os.listdir(pointer_path)
+        ):
             return pointer_path
-    else:  # pragma: no cover
-        from huggingface_hub import snapshot_download
+    from huggingface_hub import list_repo_files, snapshot_download
+
+    files_info = list_repo_files(repo_id)
+    ignore_patterns = (
+        ["*.bin", "*.bin.index.json"]
+        if (
+            any(file for file in files_info if file.endswith(".bin"))
+            and any(file for file in files_info if file.endswith(".safetensors"))
+        )
+        else None
+    )
 
-        file_path = snapshot_download(repo_id)
-        return file_path
+    file_path = snapshot_download(repo_id, ignore_patterns=ignore_patterns)
+    return file_path
 
 
 def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
```
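When a repo ships both `.bin` and `.safetensors` weights, the new `ignore_patterns` skips the redundant `.bin` shards so `snapshot_download` fetches roughly half the bytes. The equivalent standalone call (the repo id is illustrative):

```python
from huggingface_hub import list_repo_files, snapshot_download

repo_id = "facebook/opt-125m"  # illustrative repo id
files_info = list_repo_files(repo_id)
has_bin = any(f.endswith(".bin") for f in files_info)
has_safetensors = any(f.endswith(".safetensors") for f in files_info)
# Only skip the .bin weights when safetensors duplicates are available.
ignore_patterns = ["*.bin", "*.bin.index.json"] if (has_bin and has_safetensors) else None
local_dir = snapshot_download(repo_id, ignore_patterns=ignore_patterns)
```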

neural_compressor/transformers/models/modeling_auto.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -157,7 +157,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             has_remote_code,
         )
 
-        model = load_empty_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
+        model = load_empty_model(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+        )
         if use_cpu:
             quantization_config.post_init_cpu()
         elif use_xpu:
```

test/3x/torch/quantization/weight_only/test_transfomers.py

Lines changed: 41 additions & 2 deletions
```diff
@@ -122,7 +122,7 @@ def test_use_layer_wise(self):
         dummy_input = fp32_model.dummy_inputs["input_ids"]
 
         # RTN
-        # use_layer_wise=True
+        # Case1: use_layer_wise=True
         woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
         woq_model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
@@ -139,7 +139,7 @@ def test_use_layer_wise(self):
         loaded_output = loaded_model(dummy_input)[0]
         assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
 
-        # use_layer_wise=False
+        # Case2: use_layer_wise=False
         woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=False)
         woq_model = AutoModelForCausalLM.from_pretrained(
             model_name_or_path,
@@ -148,6 +148,45 @@ def test_use_layer_wise(self):
         woq_output2 = woq_model(dummy_input)[0]
         assert torch.equal(woq_output, woq_output2), "use_layer_wise output should be same. Please double check."
 
+        # Case3: test safetensors model file
+        from neural_compressor.torch.algorithms.layer_wise.utils import get_path
+
+        model_path = get_path(model_name_or_path)
+        from transformers import AutoModelForCausalLM as RawAutoModelForCausalLM
+
+        ori_model = RawAutoModelForCausalLM.from_pretrained(model_name_or_path)
+        # test 1 safetensors file
+        ori_model.save_pretrained(model_path, safe_serialization=True)
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
+
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output_1_safetensors = woq_model(dummy_input)[0]
+        assert torch.equal(woq_output, woq_output_1_safetensors)
+
+        # test 3 safetensors files
+        ori_model.save_pretrained(model_path, safe_serialization=True, max_shard_size="250KB")
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output_3_safetensors = woq_model(dummy_input)[0]
+        assert torch.equal(woq_output, woq_output_3_safetensors)
+
+        # case4: test dowload_hf_model
+        shutil.rmtree(model_path, ignore_errors=True)
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
+
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output_download = woq_model(dummy_input)[0]
+        assert torch.equal(woq_output_download, woq_output)
+
     def test_loading_autoawq_model(self):
         user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
         tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)
```
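Case3's sharded fixture comes from `save_pretrained` with a `max_shard_size` below the total weight size, which emits several `model-0000X-of-0000N.safetensors` shards plus a `model.safetensors.index.json` weight map. A sketch of that setup in isolation (the tiny model id and output directory are illustrative):

```python
# Illustrative shard-fixture setup; the model id and output dir are examples.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
# A max_shard_size below the total weight size forces multiple shards plus
# a model.safetensors.index.json mapping each tensor to its shard file.
model.save_pretrained("./sharded_ckpt", safe_serialization=True, max_shard_size="250KB")
```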
