@@ -41,7 +41,6 @@
     load_pretrained_quantization,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -62,7 +61,7 @@
 from transformers.file_utils import CONFIG_NAME


-__all__ = ["ModelCompressor", "map_modules_to_quant_scheme"]
+__all__ = ["ModelCompressor", "map_module_to_scheme"]

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -373,11 +372,8 @@ def compress(
         if state_dict is None:
             state_dict = model.state_dict()

-        module_to_scheme: Dict[str, QuantizationScheme] = map_modules_to_quant_scheme(
-            model
-        )
-
         if self.quantization_compressor is not None:
+            module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=module_to_scheme
             )
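For context, a minimal sketch of how this entry point might be driven end to end. `ModelCompressor.from_pretrained_model` and the `compress(model)` call pattern are assumptions about the surrounding API, not part of this change:

```python
# Hedged usage sketch -- the construction API is an assumption, not from this diff.
from compressed_tensors.compressors import ModelCompressor

compressor = ModelCompressor.from_pretrained_model(model)  # assumed helper
if compressor is not None:
    # compress() now resolves module_to_scheme internally via
    # map_module_to_scheme(model), and only when a quantization
    # compressor is configured (see the hunk above).
    compressed_state_dict = compressor.compress(model)
```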
@@ -521,7 +517,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             update_parameter_data(module, data, param_name)


-def map_modules_to_quant_scheme(
+def map_module_to_scheme(
     model: Module,
 ) -> Dict[str, QuantizationScheme]:
     """
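The body of the renamed helper is not shown in this hunk, but given the `iter_named_leaf_modules` and `is_module_quantized` imports retained above, a plausible implementation is a single dictionary comprehension over the model's quantized leaf modules; this is a sketch, not necessarily the exact code merged here:

```python
def map_module_to_scheme(
    model: Module,
) -> Dict[str, QuantizationScheme]:
    """
    Map each quantized leaf module's name to its quantization scheme.
    Sketch based on the imports visible in this diff; the merged body
    may differ (e.g. in module-name normalization).
    """
    return {
        name: module.quantization_scheme
        for name, module in iter_named_leaf_modules(model)
        if is_module_quantized(module)
    }
```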