
Commit c1b06de

reduce memory requirements, clarify map_modules_to_quant_scheme
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4574747 commit c1b06de

4 files changed: +103 -92 lines changed
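The memory reduction comes from compressing the state dict in place: instead of accumulating a second `compressed_dict` alongside the original, the compressors now mutate the input dict directly, deleting each uncompressed tensor as soon as its replacement is written. A minimal sketch of the pattern (illustrative only, not the library code; `compress_one` is a hypothetical stand-in for the real per-tensor compression):

from typing import Dict

import torch


def compress_one(value: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in: the real code quantizes/packs the tensor
    return value


def compress_in_place(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # snapshot the keys so entries can be deleted/replaced while iterating
    for name in list(state_dict.keys()):
        value = state_dict.pop(name)
        # at most one uncompressed tensor is alive alongside its compressed
        # form, rather than a full second dict holding every compressed tensor
        state_dict[name] = compress_one(value).to("cpu")
    return state_dict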

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 14 additions & 22 deletions
@@ -35,6 +35,7 @@
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
+    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization,
@@ -61,7 +62,7 @@
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+__all__ = ["ModelCompressor", "map_modules_to_quant_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,15 +373,13 @@ def compress(
         if state_dict is None:
             state_dict = model.state_dict()
 
-        compressed_state_dict = state_dict
-
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
+        module_to_scheme: Dict[str, QuantizationScheme] = map_modules_to_quant_scheme(
+            model
+        )
 
         if self.quantization_compressor is not None:
-            compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=quantized_modules_to_args
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
             )
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
@@ -393,8 +392,8 @@ def compress(
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict,
+            state_dict = self.sparsity_compressor.compress(
+                state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -403,7 +402,7 @@ def compress(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return compressed_state_dict
+        return state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -522,9 +521,9 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             update_parameter_data(module, data, param_name)
 
 
-def map_modules_to_quant_args(
+def map_modules_to_quant_scheme(
     model: Module,
-) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
+) -> Dict[str, QuantizationScheme]:
     """
     Given a pytorch model, map out the submodule name (usually linear layers)
     to the weight QuantizationArgs. If running input activation quantization, will also
@@ -535,15 +534,8 @@ def map_modules_to_quant_args(
     quantized_modules_to_args = {}
    for name, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
-            if submodule.quantization_scheme.weights is not None:
-                name = fix_fsdp_module_name(name)
-                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
-            if submodule.quantization_scheme.input_activations is not None:
-                weight_args = quantized_modules_to_args.get(name)
-                quantized_modules_to_args[name] = (
-                    weight_args,
-                    submodule.quantization_scheme.input_activations,
-                )
+            name = fix_fsdp_module_name(name)
+            quantized_modules_to_args[name] = submodule.quantization_scheme
 
     return quantized_modules_to_args
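Callers that previously received a per-module `QuantizationArgs` (or a `(weight_args, input_args)` tuple) now receive the whole `QuantizationScheme` and select the field they need. A hedged usage sketch, assuming a model already quantized with compressed-tensors (the import path is taken from the file above; the summary format is illustrative):

from typing import Dict

from torch.nn import Module

from compressed_tensors.compressors.model_compressors.model_compressor import (
    map_modules_to_quant_scheme,
)


def summarize_schemes(model: Module) -> Dict[str, str]:
    # assumes leaf modules carry a `quantization_scheme` attribute,
    # as set up by apply_quantization_config
    summary = {}
    for name, scheme in map_modules_to_quant_scheme(model).items():
        weight_args = scheme.weights            # QuantizationArgs or None
        input_args = scheme.input_activations   # QuantizationArgs or None
        summary[name] = f"weights={weight_args}, inputs={input_args}"
    return summary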

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 81 additions & 69 deletions
@@ -18,11 +18,12 @@
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization import QuantizationScheme
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
+    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -69,95 +70,82 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
-        :param model_state: state dict of uncompressed model
+        :param model_state: state dict of uncompressed model, consumed during
+            compression
         :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
-        compressed_dict = {}
-        weight_suffix = ".weight"
-        input_zp_suffix = ".input_zero_point"
-        weight_zp_suffix = ".weight_zero_point"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
+        save_device = "cpu"
+
+        uncompressed_names = list(model_state.keys())
+        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+            value = model_state[name]
+
+            # compress weights
+            if name.endswith(".weight"):
+                prefix = remove_suffix(name, ".weight")
 
-        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
-            # check if the parameter we're compressing is the weight zp
-            # or the input zp
-            is_weight_zp = name.endswith(weight_zp_suffix)
-            is_input_zp = name.endswith(input_zp_suffix)
-
-            # if we're saving the weight zp, fetch weight quant args
-            if is_weight_zp:
-                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
-                if isinstance(quant_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    quant_args_zp = quant_args_zp[0]
-
-            # if we're saving the input zp, fetch input quant args
-            if is_input_zp:
-                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
-                if isinstance(input_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    input_args_zp = input_args_zp[-1]
-
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
+                # gather qparams
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    if isinstance(names_to_scheme[prefix], tuple):
-                        quant_args = names_to_scheme[prefix][0]
-                    else:
-                        quant_args = names_to_scheme[prefix]
-
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            # only save if asym
-            elif is_weight_zp and quant_args_zp.symmetric:
-                continue
-            # only save if asym
-            elif is_input_zp and input_args_zp.symmetric:
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+
+                # if the scale does not exist, then the weight cannot be compressed
+                if scale is None:
+                    model_state[name] = value.to(save_device)
+                    continue
+
+                # compress values on cpu. TODO: experiment with different devices
+                quant_args = names_to_scheme[prefix].weights
+                compressed_values = self.compress_weight(
+                    weight=value,
+                    scale=scale,
+                    zero_point=zp,
+                    g_idx=g_idx,
+                    quantization_args=quant_args,
+                    device="cpu",
+                )
+
+                # update state dict
+                del model_state[name]
+                for key, value in compressed_values.items():
+                    model_state[merge_names(prefix, key)] = value.to(save_device)
+
             else:
-                compressed_dict[name] = value.to("cpu")
+                # omit saving zero points for symmetric quantization
+                if name.endswith("zero_point") and _is_symmetric(name, names_to_scheme):
+                    del model_state[name]
 
-        return compressed_dict
+                # omit saving g_idx if uninitialized
+                # TODO: does this case actually occur?
+                elif name.endswith("g_idx") and torch.any(value <= -1):
+                    del model_state[name]
+
+                else:
+                    model_state[name] = value.to(save_device)
+
+        return model_state
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationArgs],
-        device: str = "cpu",
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: torch.device = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization args for each quantized weight
+        :param names_to_scheme: quantization scheme for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
@@ -171,7 +159,12 @@ def decompress(
             path_to_model_or_tensors, names_to_scheme
         )
 
-    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: torch.device,
+    ):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -182,13 +175,17 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
             with safe_open(safe_path, framework="pt", device=device) as f:
                 weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[weight_name].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 yield merge_names(weight_name, "weight"), decompressed
 
-    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+    def _decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, torch.Tensor],
+        names_to_scheme: Dict[str, QuantizationScheme],
+    ):
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
@@ -198,8 +195,23 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
                 weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[weight_name].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 yield merge_names(weight_name, "weight"), decompressed
+
+
+def _is_symmetric(name: str, names_to_scheme: Dict[str, QuantizationScheme]) -> bool:
+    weight_name, zp_name = name.rsplit(".", 1)
+    scheme = names_to_scheme[weight_name]
+
+    if zp_name == "weight_zero_point":
+        quant_args = scheme.weights
+    if zp_name == "input_zero_point":
+        quant_args = scheme.input_activations
+    if zp_name == "output_zero_point":
+        quant_args = scheme.output_activations
+
+    assert quant_args is not None
+    return quant_args.symmetric
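For the new `_is_symmetric` helper, note that `names_to_scheme` is keyed on the module prefix, so `rsplit(".", 1)` separates the module path from the zero-point suffix, which in turn selects which args to consult. A worked example tracing the lookup (the module path is illustrative, and `_is_symmetric` is a private helper, so this only runs in the scope of base.py itself):

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

name = "model.layers.0.mlp.down_proj.weight_zero_point"
module_name, zp_name = name.rsplit(".", 1)
assert module_name == "model.layers.0.mlp.down_proj"
assert zp_name == "weight_zero_point"  # selects scheme.weights

# symmetric weight quantization: the zero point is identically zero,
# so compress() deletes it from the state dict instead of saving it
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
)
assert _is_symmetric(name, {module_name: scheme})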

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def compress(
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(model_state.items(), desc="Compressing with sparsity"):
             if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue

src/compressed_tensors/utils/helpers.py

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,7 @@
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "remove_suffix",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +329,9 @@ def unpack_bitmasks(
     )
 
     return unpacked_bitmasks_torch
+
+
+def remove_suffix(value: str, suffix: str) -> str:
+    # can replace with str.removesuffix in python3.9+
+    assert value.endswith(suffix)
+    return value[: -len(suffix)]
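A quick usage note on `remove_suffix`: unlike `str.removesuffix`, which is a no-op when the suffix is absent, this helper asserts the suffix is present, matching how the compressor only calls it on names already known to end in `.weight`:

from compressed_tensors.utils import remove_suffix

assert remove_suffix("model.layers.0.q_proj.weight", ".weight") == "model.layers.0.q_proj"

# absent suffix: str.removesuffix returns the input unchanged,
# while remove_suffix raises AssertionError instead
# remove_suffix("model.layers.0.q_proj.bias", ".weight")  # AssertionError

One edge case worth noting: an empty suffix passes the assertion but returns `value[:0]`, i.e. the empty string, so callers should only pass non-empty suffixes.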
