Skip to content

Commit 59f02b5

Browse files
authored
update name (#310)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent c84b5b4 commit 59f02b5

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,8 @@ def _replace_weights(self, dense_weight_generator, model: Module):
573573
:param model: The model whose weights are to be updated.
574574
"""
575575

576-
for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
577-
module = operator.attrgetter(name)(model)
576+
for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
577+
module = operator.attrgetter(mod_path)(model)
578578

579579
params_device = next(module.parameters()).device
580580
device = "cpu" if has_offloaded_params(module) else params_device

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,33 +195,33 @@ def _decompress_from_path(
195195
weight_mappings = get_nested_weight_mappings(
196196
path_to_model, self.compression_param_names
197197
)
198-
for weight_name in weight_mappings.keys():
198+
for module_path in weight_mappings.keys():
199199
weight_data = {}
200-
for param_name, safe_path in weight_mappings[weight_name].items():
201-
full_name = merge_names(weight_name, param_name)
200+
for param_name, safe_path in weight_mappings[module_path].items():
201+
full_name = merge_names(module_path, param_name)
202202
with safe_open(safe_path, framework="pt", device=device) as f:
203203
weight_data[param_name] = f.get_tensor(full_name)
204204
if "weight_scale" in weight_data:
205-
quant_args = names_to_scheme[weight_name].weights
205+
quant_args = names_to_scheme[module_path].weights
206206
decompressed = self.decompress_weight(
207207
compressed_data=weight_data, quantization_args=quant_args
208208
)
209209
weight_data["weight"] = decompressed
210-
yield weight_name, weight_data
210+
yield module_path, weight_data
211211

212212
def _decompress_from_state_dict(self, state_dict, names_to_scheme):
213213
weight_mappings = get_nested_mappings_from_state_dict(
214214
state_dict, self.compression_param_names
215215
)
216-
for weight_name in weight_mappings.keys():
216+
for module_path in weight_mappings.keys():
217217
weight_data = {}
218-
for param_name, param_value in weight_mappings[weight_name].items():
218+
for param_name, param_value in weight_mappings[module_path].items():
219219
weight_data[param_name] = param_value
220220

221221
if "weight_scale" in weight_data:
222-
quant_args = names_to_scheme[weight_name]
222+
quant_args = names_to_scheme[module_path]
223223
decompressed = self.decompress_weight(
224224
compressed_data=weight_data, quantization_args=quant_args
225225
)
226226
weight_data["weight"] = decompressed
227-
yield weight_name, weight_data
227+
yield module_path, weight_data

0 commit comments

Comments
 (0)