Commit dbc104d

wip: decompression works except for zero points
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 0036e21 commit dbc104d

File tree

4 files changed: +112 −48 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 42 additions & 2 deletions
@@ -382,12 +382,15 @@ def compress_model(self, model: Module):
         for prefix, module in model.named_modules():
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 state_dict = module.state_dict(prefix=f"{prefix}.")
+                # quantization first
                 if prefix in module_to_scheme:
                     state_dict = self.quantization_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
                     )
+
+                # sparsity second
                 if prefix in sparse_compression_targets:
                     state_dict = self.sparsity_compressor.compress(
                         state_dict,
@@ -407,9 +410,46 @@ def compress_model(self, model: Module):
 
                 module.quantization_status = QuantizationStatus.COMPRESSED
 
-    def decompress_model(model: Module):
+    def decompress_model(self, model: Module):
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in model.named_modules():
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                state_dict = module.state_dict(prefix=f"{prefix}.")
+                decompressed = dict()
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    for _module_name, decompressed_data in generator:
+                        decompressed.update(decompressed_data)
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    for _module_name, decompressed_data in generator:
+                        decompressed.update(decompressed_data)
+
+                # remove any exist parameters
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
 
-        pass
+                # replace with decompressed parameters
+                for name, value in decompressed.items():
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
 
         # apparently we only have logic for decompressing from a file...
 
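
To see how the two in-place paths are intended to fit together, here is a minimal usage sketch of the compress/decompress round trip this commit is building toward. The import path and the model stub / format strings are assumptions borrowed from the test file further down, not part of this diff.

from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor  # import path assumed

# load an uncompressed checkpoint (placeholder stub taken from the tests below)
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed"
)

# compress in place: quantization first, then sparsity
compressor = ModelCompressor.from_pretrained_model(
    model, "sparse-24-bitmask", "float-quantized"
)
compressor.compress_model(model)

# decompress in place: sparsity first, then quantization
compressor.decompress_model(model)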

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 27 additions & 23 deletions
@@ -185,46 +185,50 @@ def decompress(
             )
 
         else:
-            yield from self._decompress_from_state_dict(
+            yield from self.decompress_from_state_dict(
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(
+    def decompress_from_state_dict(
         self,
-        path_to_model: Union[str, Path, Dict[str, Any]],
+        state_dict: Dict[str, torch.Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
-        device: str,
-    ):
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model, self.compression_param_names
+    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names
         )
-        for weight_name in weight_mappings.keys():
+        for module_name in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
+            for param_name, param_value in weight_mappings[module_name].items():
+                weight_data[param_name] = param_value
+
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name].weights
+                quant_args = names_to_scheme[module_name]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield weight_name, weight_data
+            yield module_name, weight_data
 
-    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
-        weight_mappings = get_nested_mappings_from_state_dict(
-            state_dict, self.compression_param_names
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: str,
+    ):
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model, self.compression_param_names
         )
-        for weight_name in weight_mappings.keys():
+        for module_name in weight_mappings.keys():
             weight_data = {}
-            for param_name, param_value in weight_mappings[weight_name].items():
-                weight_data[param_name] = param_value
-
+            for param_name, safe_path in weight_mappings[module_name].items():
+                full_name = merge_names(module_name, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[module_name].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield weight_name, weight_data
+            yield module_name, weight_data
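
For orientation, a toy illustration of the nested mapping that the promoted decompress_from_state_dict iterates over. The layer name, parameter names, and tensor shapes below are made up for illustration, assuming the {module_name: {param_name: tensor}} layout implied by the loop body.

import torch

# hypothetical flat state dict for one compressed linear layer
state_dict = {
    "model.layers.0.self_attn.q_proj.weight_packed": torch.zeros(64, 16, dtype=torch.int32),
    "model.layers.0.self_attn.q_proj.weight_scale": torch.ones(64, 1),
}

# get_nested_mappings_from_state_dict(...) groups the entries by module, roughly:
weight_mappings = {
    "model.layers.0.self_attn.q_proj": {
        "weight_packed": state_dict["model.layers.0.self_attn.q_proj.weight_packed"],
        "weight_scale": state_dict["model.layers.0.self_attn.q_proj.weight_scale"],
    }
}

# each module whose entries include "weight_scale" is then decompressed and
# yielded as (module_name, weight_data) with a dense "weight" tensor added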

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 3 additions & 0 deletions
@@ -151,6 +151,9 @@ def decompress(
                     value = f.get_tensor(ignored_param_name)
                     yield ignored_param_name, value
 
+    def decompress_from_state_dict(self, state_dict, names_to_scheme):
+        exit(0)
+
     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
         """

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 40 additions & 23 deletions
@@ -370,52 +370,69 @@ def _get_combined_config(s_config, q_config):
 
 
 @pytest.mark.parametrize(
-    "model_stub,q_format,s_format",
+    "model_stub,comp_stub,q_format,s_format",
     [
         (
             "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
             "float-quantized",
             None,
         ),
         (
             "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
            None,
            "sparse-24-bitmask",
        ),
        (
            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
            "float-quantized",
            "sparse-24-bitmask",
        ),
    ],
 )
-def test_compress_model(model_stub, q_format, s_format):
+def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
     model = AutoModelForCausalLM.from_pretrained(model_stub)
     compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
-    original_compressed_state_dict = dict(compressor.compress(model))
-    original_compressed_state_dict = {
-        key: value.clone() for key, value in original_compressed_state_dict.items()
-    }
 
-    compressor.compress_model(model)
+    # compress model by eagerly compressing state dict
+    true_compressed = dict(compressor.compress(model))
+    true_compressed = {key: value.clone() for key, value in true_compressed.items()}
 
-    for module in model.modules():
-        # scheme <=> CompressedLinear
-        has_scheme = hasattr(module, "quantization_scheme")
-        is_compressed = (
-            getattr(module, "quantization_status", None)
-            == QuantizationStatus.COMPRESSED
-        )
-        # assert has_scheme == is_compressed
+    # compress model directly
+    compressor.compress_model(model)
+    compressed = dict(model.state_dict())
 
     # equivalent to eagerly compressing state dict
-    compressed_state_dict = dict(model.state_dict())
-    assert compressed_state_dict.keys() == original_compressed_state_dict.keys()
-    for key in compressed_state_dict.keys():
-        assert torch.all(
-            compressed_state_dict[key] == original_compressed_state_dict[key]
-        ), f"{key}"
-
-    # decompress
+    assert compressed.keys() == true_compressed.keys()
+    for key in compressed.keys():
+        assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
+
+    del compressed
+    del true_compressed
+
+    # -- decompress -- #
+
+    # reinstantiate compressor to mimic LLM Compressor flows
     compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
+
+    # decompress model from disk  # TODO try also using a model saved from prev step
+    true_decompressed_model = AutoModelForCausalLM.from_pretrained(
+        model_stub, device_map="meta"
+    )
+    compressor.decompress(comp_stub, true_decompressed_model)
+    true_decompressed = dict(true_decompressed_model.state_dict())
+
+    # decompress model
     compressor.decompress_model(model)
+    decompressed = dict(model.state_dict())
+
+    # equivalent to decompressing from disk
+    breakpoint()
+    assert decompressed.keys() == true_decompressed.keys()
+    for key in decompressed.keys():
+        assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
+    del true_decompressed
+
+    exit(0)
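
To exercise just this round-trip test, the standard pytest entry point can be invoked from the CLI or from Python; a small sketch of the latter is below. The file path and -k expression come from the diff above; the flags are ordinary pytest options.

import pytest

# run only the compress/decompress round-trip test, stopping at the first failure
pytest.main([
    "tests/test_compressors/model_compressors/test_model_compressor.py",
    "-k", "test_compress_decompress_model",
    "-x",
])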
