Skip to content

Commit f45b4a1

Browse files
committed
Move validate safetensors function
Add tests for the same
1 parent 343aa8d commit f45b4a1

File tree

5 files changed

+160
-51
lines changed

5 files changed

+160
-51
lines changed

src/compressed_tensors/utils/converters/converters.py

+46-40
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
import torch
2424
from compressed_tensors.registry.registry import RegistryMixin
2525
from compressed_tensors.utils.converters.transformations import (
26+
remove_unused_tensors,
2627
transform_autogptq_weights_and_reshape_tensors,
2728
transform_exllama_names,
2829
)
30+
from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
2931
from safetensors import safe_open
3032
from safetensors.torch import save_file
3133
from tqdm import tqdm
@@ -38,7 +40,7 @@
3840

3941

4042
class ConverterNames(str, Enum):
    """
    Registry names for the available checkpoint converters

    Members also subclass ``str``, so a member compares equal to its
    plain string value
    """

    # name used when registering and loading the AutoGPTQ converter
    AutoGPTQConverter = "exllama_to_compressed_tensor"
    # previous member name; an Enum alias (same value, same member object)
    # so code written against the old name keeps resolving
    EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor"
4244

4345

4446
class BaseConverter(ABC, RegistryMixin):
@@ -71,7 +73,7 @@ def convert_from_safetensors(
7173
:param save_dir: The directory to save the converted state_dict to
7274
:return: The directory where the converted state_dict was saved
7375
"""
74-
_validate_safetensors_file_path(filepath)
76+
validate_safetensors_file_path(filepath)
7577

7678
filepath_: Path = Path(filepath)
7779
if not save_dir:
@@ -84,30 +86,42 @@ def convert_from_safetensors(
8486
# transform and save the state_dict
8587
if filepath_.is_dir():
8688
tqdm.write(f"Converting directory: {filepath}")
87-
tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files")
89+
tqdm.write(
90+
f"Found: {len(list(filepath_.glob('*.safetensors')))} "
91+
".safetensors files"
92+
)
8893
for file in filepath_.glob("*.safetensors"):
8994
tqdm.write(f"Converting file: {file.name}")
9095
new_state_dict = {}
9196
state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
9297
file, by_layers=True
9398
)
94-
layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers")
99+
layer_progress_bar = tqdm(
100+
state_dict, total=layer_count(file), desc="Converting layers"
101+
)
95102
for layer_state_dict in layer_progress_bar:
96-
layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")]
103+
layer_name = list(layer_state_dict.keys())[0][
104+
: len("model.layers.0")
105+
]
97106
layer_progress_bar.set_description(f"Converting layer {layer_name}")
98107
layer_progress_bar.update()
99108
new_state_dict.update(
100109
cls.translate(state_dict=layer_state_dict, **kwargs)
101110
)
102111

103112
if new_state_dict:
113+
# compress before saving
114+
# compressor = Compressor.load_from_registry(
115+
# name=CompressionFormat.pack_quantized.value
116+
# )
117+
# new_state_dict = compressor.compress(new_state_dict)
104118
save_file(
105119
new_state_dict,
106120
filename=save_dir_ / file.name,
107121
metadata=metadata,
108122
)
109123
_copy_non_safetensor_files_(filepath_, save_dir_)
110-
_update_quantization_config(filepath_, save_dir_)
124+
# _update_quantization_config(filepath_, save_dir_)
111125

112126
elif filepath_.is_file():
113127
new_state_dict = {}
@@ -134,39 +148,28 @@ def transformations(cls) -> Iterable[TransformationType]:
134148
raise NotImplementedError()
135149

136150

137-
@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR)
138-
class ExllamaToCompressedTensorConverter(BaseConverter):
151+
@BaseConverter.register(name=ConverterNames.AutoGPTQConverter)
152+
class AutoGPTQConverter(BaseConverter):
139153
"""
140154
A converter that applies transformations to the state_dict of an autogptq
141-
quantized model to convert it to a compressed tensor model, which can be
142-
loaded by the SparseAutoModel classes
143-
"""
144-
145-
@classmethod
146-
def transformations(cls):
147-
return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names)
155+
quantized model to convert it to a compressed tensor model
148156
157+
Transformations made:
149158
150-
def _validate_safetensors_file_path(filepath: str):
151-
"""
152-
Given a file path, it is valid if:
153-
- The file exists
154-
- The file is either a single .safetensors file or a
155-
directory containing .safetensors files
156-
157-
:param filepath: A string file path to validate
159+
-> Unpack autogptq 4 bit weight packing
160+
-> Translate exllama names to compressed tensor names
161+
-> Pack 4 bit weights with compressed tensor format
162+
-> Remove unused tensors
163+
-> Update quantization config in config.json file
158164
"""
159165

160-
filepath_: Path = Path(filepath)
161-
162-
if not filepath_.exists():
163-
raise FileNotFoundError(f"File not found: {filepath}")
164-
165-
if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
166-
raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")
167-
168-
if filepath_.is_file() and not filepath_.suffix == ".safetensors":
169-
raise ValueError(f"File must be a .safetensors file: {filepath}")
166+
@classmethod
167+
def transformations(cls):
168+
return (
169+
transform_autogptq_weights_and_reshape_tensors,
170+
transform_exllama_names,
171+
remove_unused_tensors,
172+
)
170173

171174

172175
def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
@@ -178,7 +181,7 @@ def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
178181
:param dest_dir: The directory to copy files to
179182
"""
180183
for file in source_dir.glob("*"):
181-
if file.suffix != ".safetensors":
184+
if file.suffix != ".safetensors" and file.name != "config.json":
182185
_LOGGER.info(f"Copying file: {file} to {dest_dir}")
183186
shutil.copy(file, dest_dir / file.name)
184187

@@ -198,7 +201,9 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path):
198201
if hasattr(config, "quantization_config"):
199202
_LOGGER.info("Updating quantization config...")
200203
quantization_config = config.quantization_config
201-
config.quantization_config = _convert_to_compressed_tensors_config(quantization_config)
204+
config.quantization_config = _convert_to_compressed_tensors_config(
205+
quantization_config
206+
)
202207
config.save_pretrained(dest_dir)
203208

204209

@@ -207,12 +212,14 @@ def _convert_to_compressed_tensors_config(quantization_config):
207212
Converts the quantization_config attribute from a config.json file
208213
to a dictionary
209214
210-
:param quantization_config: The quantization_config attribute from a config.json file
215+
:param quantization_config: The quantization_config
216+
attribute from a config.json file
211217
:return: The quantization_config as a dictionary
212218
"""
213219
compressed_tensor_config = ...
214220
return compressed_tensor_config
215221

222+
216223
def layer_count(file_path: str) -> int:
217224
"""
218225
Count the number of layers in a safetensors file
@@ -222,16 +229,15 @@ def layer_count(file_path: str) -> int:
222229
"""
223230
with safe_open(file_path, framework="pt", device="cpu") as f:
224231
keys = sorted(f.keys())
225-
232+
226233
last_layer_name = None
227234
layer_count = 0
228235
for key in keys:
229-
layer_name = key[:len("model.layers.0")]
236+
layer_name = key[: len("model.layers.0")]
230237
if layer_name != last_layer_name:
231238
last_layer_name = layer_name
232239
layer_count += 1
233240
return layer_count
234-
235241

236242

237243
def load_safetensors_state_dict(
@@ -251,7 +257,7 @@ def load_safetensors_state_dict(
251257
current_layer = None
252258
layer_data = {}
253259
for key in sorted(f.keys()):
254-
layer_name = key[:len("model.layers.0")]
260+
layer_name = key[: len("model.layers.0")]
255261
if current_layer is None:
256262
current_layer = layer_name
257263
elif layer_name != current_layer:

src/compressed_tensors/utils/converters/main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515

1616
from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames
1717

18+
1819
__all__ = ["convert_autogptq_checkpoint"]
1920

2021

2122
def convert_autogptq_checkpoint(
22-
old_checkpoint_path, new_checkpoint_path ,**kwargs
23+
old_checkpoint_path, new_checkpoint_path, **kwargs
2324
) -> str:
2425
"""
2526
Convert an autogptq checkpoint to a compressed tensor checkpoint
@@ -31,7 +32,7 @@ def convert_autogptq_checkpoint(
3132
:return: the path to the new checkpoint
3233
"""
3334
converter: BaseConverter = BaseConverter.load_from_registry(
34-
ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR
35+
ConverterNames.AutoGPTQConverter
3536
)
3637
checkpoint_path = converter.convert_from_safetensors(
3738
old_checkpoint_path, new_checkpoint_path, **kwargs

src/compressed_tensors/utils/converters/transformations.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
6969
name_map: Dict[str, str] = {
7070
".scales": ".weight_scale",
7171
".qzeros": ".weight_zero_point",
72-
".qweight": ".weight",
72+
".qweight": ".weight_packed",
7373
}
7474

7575
updated_state_dict = {}
@@ -91,7 +91,7 @@ def transform_autogptq_weights_and_reshape_tensors(
9191
to CompressedTensors conversion
9292
9393
The transformations include:
94-
- Unpack ad dequantize the weight tensor using the scales, zeros, and g_idx tensors
94+
- Unpack and dequantize the weight tensor using the scales, zeros, and g_idx tensors
9595
- Squeeze the scales tensor to [x] from [1, x]
9696
9797
:pre-condition: The state_dict should be for a quantized model
@@ -117,13 +117,15 @@ def transform_autogptq_weights_and_reshape_tensors(
117117
g_idx = state_dict[key.replace("qweight", "g_idx")]
118118

119119
zeros = unpack_zeros(qzeros)
120-
qweight = unpack_int32_into_fp32(
121-
qweight=tensor,
122-
scales=scales,
123-
zeros=zeros,
124-
g_idx=g_idx,
125-
)
126-
transformed_weights_dict[key] = qweight
120+
# qweight = unpack_int32_into_fp32(
121+
# qweight=tensor,
122+
# scales=scales,
123+
# zeros=zeros,
124+
# g_idx=g_idx,
125+
# )
126+
new_shape = torch.tensor([tensor.shape[0] * 8, tensor.shape[1]])
127+
transformed_weights_dict[key] = tensor
128+
transformed_weights_dict[key.replace("qweight", "weight_shape")] = new_shape
127129

128130
# transform scales
129131
for key, tensor in state_dict.items():
@@ -222,3 +224,17 @@ def unpack_int32_into_fp32(
222224
weight = torch.cat(weight, dim=1)
223225

224226
return weight
227+
228+
229+
def remove_unused_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Filter a state_dict down to the tensors that are kept after conversion

    An entry is kept only when its key is a gptq quantization target and
    does not end with ".g_idx"; every other entry is dropped. The input
    dict is not modified

    :param state_dict: The state_dict to be cleaned
    :return: A new dict holding only the retained entries
    """
    cleaned_state_dict = {}
    for name, value in state_dict.items():
        # skip anything that is not a quantization target
        if not is_gptq_quantization_target(name):
            continue
        # g_idx tensors are not carried over to the converted checkpoint
        if name.endswith(".g_idx"):
            continue
        cleaned_state_dict[name] = value
    return cleaned_state_dict

src/compressed_tensors/utils/safetensors_load.py

+23
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
import struct
19+
from pathlib import Path
1920
from typing import Dict, List, Optional
2021

2122
from safetensors import safe_open
@@ -236,3 +237,25 @@ def is_quantization_param(name: str) -> bool:
236237
return True
237238

238239
return False
240+
241+
242+
def validate_safetensors_file_path(filepath: str) -> None:
    """
    Given a file path, it is valid if:
        - The path exists
        - The path is either a single .safetensors file or a
            directory containing .safetensors files

    :param filepath: A string file path to validate
    :raises FileNotFoundError: If the path does not exist, or if it is a
        directory that directly contains no .safetensors files
    :raises ValueError: If the path is not a directory and is not a
        regular file with the .safetensors suffix
    """
    filepath_: Path = Path(filepath)

    if not filepath_.exists():
        raise FileNotFoundError(f"File not found: {filepath}")

    if filepath_.is_dir():
        # non-recursive glob: only files directly inside the directory count
        if not any(filepath_.glob("*.safetensors")):
            raise FileNotFoundError(
                f"No .safetensors files found in directory: {filepath}"
            )
        return

    # everything that is not a directory must be a regular file carrying the
    # .safetensors suffix; this also rejects paths that exist but are neither
    # a regular file nor a directory instead of silently accepting them
    if not filepath_.is_file() or filepath_.suffix != ".safetensors":
        raise ValueError(f"File must be a .safetensors file: {filepath}")
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
17+
18+
19+
@pytest.fixture
20+
def temp_dir(tmp_path):
21+
return tmp_path / "subdirectory"
22+
23+
24+
@pytest.fixture
def safetensors_file(temp_dir):
    """
    Create a dummy file named test.safetensors inside ``temp_dir``

    :param temp_dir: directory path supplied by the ``temp_dir`` fixture
    :return: path to the created file
    """
    # exist_ok lets this fixture be combined with others that create the
    # same directory; the Path.mkdir keyword is ``exist_ok``
    temp_dir.mkdir(exist_ok=True)
    safetensors_filepath = temp_dir / "test.safetensors"
    # placeholder text only; the file is never parsed as real safetensors data
    safetensors_filepath.write_text("content")
    return safetensors_filepath
30+
31+
32+
@pytest.fixture
def non_safetensors_file(temp_dir):
    """
    Create a dummy file named test.txt inside ``temp_dir``

    :param temp_dir: directory path supplied by the ``temp_dir`` fixture
    :return: path to the created file
    """
    # exist_ok lets this fixture be combined with others that create the
    # same directory; the Path.mkdir keyword is ``exist_ok``
    temp_dir.mkdir(exist_ok=True)
    non_safetensors_filepath = temp_dir / "test.txt"
    non_safetensors_filepath.write_text("content")
    return non_safetensors_filepath
38+
39+
40+
def test_validate_safetensors_file_path_file_not_found():
41+
with pytest.raises(FileNotFoundError):
42+
validate_safetensors_file_path("nonexistent_file.safetensors")
43+
44+
45+
def test_validate_safetensors_file_path_no_safetensors_files_in_directory(temp_dir):
46+
temp_dir.mkdir()
47+
with pytest.raises(FileNotFoundError):
48+
validate_safetensors_file_path(str(temp_dir))
49+
50+
51+
def test_validate_safetensors_file_path_file_is_not_safetensors(non_safetensors_file):
52+
with pytest.raises(ValueError):
53+
validate_safetensors_file_path(str(non_safetensors_file))
54+
55+
56+
def test_validate_safetensors_file_path_valid_safetensors_file(safetensors_file):
57+
validate_safetensors_file_path(str(safetensors_file))
58+
59+
60+
def test_validate_safetensors_file_path_valid_directory_with_safetensors_files(
61+
temp_dir, safetensors_file
62+
):
63+
validate_safetensors_file_path(str(temp_dir))

0 commit comments

Comments
 (0)