Commit 1c4f639

[NVFP4] Fix global scale update when dealing with offloaded layers (#1554)
SUMMARY:
- Updating the global scale inside the `align_module` context does not persist the scale parameter
- Update outside of the context so that the offloaded dict is updated as well

TESTING:
- Resolves the CPU offloading issues seen with a Llama 70B FP4 model
1 parent fa0f793 commit 1c4f639

File tree

1 file changed: +9 −11 lines changed

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 9 additions & 11 deletions
@@ -72,10 +72,11 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                 )
             ).reshape([1])
 
-            update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
-            del global_scale
+        update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
+        update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
+        update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
+
+        del global_scale
 
         if _is_mlp_module(submodule):
             if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
@@ -91,10 +92,7 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                 )
             ).reshape([1])
 
-            update_parameter_data(
-                submodule.gate_proj, global_scale, "weight_global_scale"
-            )
-            update_parameter_data(
-                submodule.up_proj, global_scale, "weight_global_scale"
-            )
-            del global_scale
+        update_parameter_data(submodule.gate_proj, global_scale, "weight_global_scale")
+        update_parameter_data(submodule.up_proj, global_scale, "weight_global_scale")
+
+        del global_scale
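
For context, here is a minimal sketch of the offloading pitfall this commit works around, assuming the `align_module_device` and `update_parameter_data` helpers from compressed-tensors that appear in the diff. The function name `refresh_global_scale` and the max-based scale computation are illustrative placeholders, not the actual NVFP4 logic:

```python
import torch
from compressed_tensors.utils import align_module_device, update_parameter_data

def refresh_global_scale(attn_module: torch.nn.Module) -> None:
    # Compute the shared scale while the (possibly offloaded) weights are
    # temporarily materialized on the execution device.
    with align_module_device(attn_module.q_proj):
        # Placeholder computation; the real code derives an NVFP4 global scale.
        global_scale = attn_module.q_proj.weight.abs().max().reshape([1])

    # Persist the scale OUTSIDE the alignment context: writes made inside it
    # only touch the temporary on-device copy, so an offloaded module's
    # weights dict would silently keep the stale value.
    update_parameter_data(attn_module.q_proj, global_scale, "weight_global_scale")
```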
