Commit b37b7ff (parent: a217cb2)

rearrange to clean up diff

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 file changed:
src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 18 additions & 18 deletions
@@ -189,19 +189,21 @@ def decompress(
             path_to_model_or_tensors, names_to_scheme
         )
 
-    def decompress_from_state_dict(
+    def _decompress_from_path(
         self,
-        state_dict: Dict[str, torch.Tensor],
+        path_to_model: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationScheme],
-    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
-        weight_mappings = get_nested_mappings_from_state_dict(
-            state_dict, self.compression_param_names
+        device: str,
+    ):
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model, self.compression_param_names
         )
         for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, param_value in weight_mappings[module_path].items():
-                weight_data[param_name] = param_value
-
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
@@ -210,21 +212,19 @@ def decompress_from_state_dict(
                 weight_data["weight"] = decompressed
             yield module_path, weight_data
 
-    def _decompress_from_path(
+    def decompress_from_state_dict(
         self,
-        path_to_model: Union[str, Path, Dict[str, Any]],
+        state_dict: Dict[str, torch.Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
-        device: str,
-    ):
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model, self.compression_param_names
+    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names
         )
         for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[module_path].items():
-                full_name = merge_names(module_path, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
+            for param_name, param_value in weight_mappings[module_path].items():
+                weight_data[param_name] = param_value
+
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
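
For context, the relocated _decompress_from_path reads each parameter lazily from its safetensors shard rather than holding a full state dict in memory. Below is a minimal sketch of that safe_open access pattern; the module path and shard filename are hypothetical, and a plain dotted join stands in for merge_names.

    from safetensors import safe_open

    # Hypothetical module and shard mapping, for illustration only.
    module_path = "model.layers.0.self_attn.q_proj"
    params = {"weight_packed": "model-00001-of-00002.safetensors"}

    weight_data = {}
    for param_name, safe_path in params.items():
        # merge_names joins the module path and parameter name;
        # a dotted join is used here as a stand-in.
        full_name = f"{module_path}.{param_name}"
        # safe_open memory-maps the shard, so only the requested
        # tensor is materialized on the target device.
        with safe_open(safe_path, framework="pt", device="cpu") as f:
            weight_data[param_name] = f.get_tensor(full_name)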

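decompress_from_state_dict keeps its generator signature, yielding (module_path, weight_data) pairs with "weight" decompressed whenever a "weight_scale" is present. A hedged usage sketch, assuming a compressor instance, a state_dict, and a names_to_scheme mapping already exist in the caller:

    # Hypothetical caller: iterate lazily over decompressed modules
    # without assembling the full uncompressed state dict at once.
    for module_path, weight_data in compressor.decompress_from_state_dict(
        state_dict, names_to_scheme
    ):
        for param_name, tensor in weight_data.items():
            print(module_path, param_name, tuple(tensor.shape))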