Commit e7c6ef4

[NVFP4] Fix onloading of fused layers (#1512)
Summary:
- Properly onload qkv and gate/up layers when updating global scales with CPU offloading

Testing:
- Tested in the memory-constrained case to ensure proper behaviour
1 parent 559ad81 · commit e7c6ef4

File tree

1 file changed (+18, −16)

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 18 additions & 16 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from compressed_tensors.utils import align_module_device, update_parameter_data
+from compressed_tensors.utils import align_modules, update_parameter_data
 from torch.nn import Linear, Module
 
 __all__ = ["update_fused_layer_weight_global_scales"]
@@ -51,17 +51,17 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                 return False
         return True
 
-    with align_module_device(submodule):
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                return
+    if _is_attention_module(submodule):
+        # already fused/treated as one layer
+        if hasattr(submodule, "qkv_proj"):
+            return
 
-            if not _valid_tensor_group_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                return
+        if not _valid_tensor_group_quant(
+            [submodule.q_proj, submodule.v_proj, submodule.k_proj]
+        ):
+            return
 
+        with align_modules([submodule.q_proj, submodule.v_proj, submodule.k_proj]):
             global_scale = torch.min(
                 torch.cat(
                     (
@@ -70,29 +70,31 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                         submodule.v_proj.weight_global_scale.data,
                     )
                 )
-            )
+            ).reshape([1])
 
             update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
             update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
             update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
+            del global_scale
 
-    with align_module_device(submodule):
-        if _is_mlp_module(submodule):
-            if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
-                return
+    if _is_mlp_module(submodule):
+        if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
+            return
 
+        with align_modules([submodule.gate_proj, submodule.up_proj]):
            global_scale = torch.min(
                torch.cat(
                    (
                        submodule.gate_proj.weight_global_scale.data,
                        submodule.up_proj.weight_global_scale.data,
                    )
                )
-            )
+            ).reshape([1])
 
             update_parameter_data(
                 submodule.gate_proj, global_scale, "weight_global_scale"
             )
             update_parameter_data(
                 submodule.up_proj, global_scale, "weight_global_scale"
             )
+            del global_scale
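
For context, below is a minimal sketch (not part of the commit) of the gate/up half of the new pattern, applied to a toy module. ToyMLP and its hand-registered weight_global_scale parameters are illustrative stand-ins for what llm-compressor's NVFP4 calibration flow attaches to real layers; align_modules and update_parameter_data are the same compressed-tensors utilities used in the diff above.

import torch
from torch.nn import Linear, Module

from compressed_tensors.utils import align_modules, update_parameter_data


class ToyMLP(Module):
    # Mirrors the gate_proj/up_proj layout that _is_mlp_module looks for.
    def __init__(self, hidden: int = 8, intermediate: int = 16):
        super().__init__()
        self.gate_proj = Linear(hidden, intermediate, bias=False)
        self.up_proj = Linear(hidden, intermediate, bias=False)
        # Hypothetical stand-in scales, registered by hand purely for this
        # example; in llm-compressor they are created during NVFP4 calibration.
        self.gate_proj.register_parameter(
            "weight_global_scale", torch.nn.Parameter(torch.tensor([3.0]))
        )
        self.up_proj.register_parameter(
            "weight_global_scale", torch.nn.Parameter(torch.tensor([2.0]))
        )


mlp = ToyMLP()

# Align the child Linear layers themselves. Per the commit summary, aligning
# only the parent module (the old align_module_device(submodule) approach)
# did not properly onload these fused projections under CPU offloading.
with align_modules([mlp.gate_proj, mlp.up_proj]):
    global_scale = torch.min(
        torch.cat(
            (
                mlp.gate_proj.weight_global_scale.data,
                mlp.up_proj.weight_global_scale.data,
            )
        )
    ).reshape([1])  # keep a 1-element tensor rather than a 0-dim scalar
    update_parameter_data(mlp.gate_proj, global_scale, "weight_global_scale")
    update_parameter_data(mlp.up_proj, global_scale, "weight_global_scale")

# Both fused layers now share the minimum scale.
assert mlp.gate_proj.weight_global_scale.item() == 2.0
assert mlp.up_proj.weight_global_scale.item() == 2.0

The sketch omits offloading itself, so it runs on a plain CPU model; the del global_scale lines in the diff appear to drop the last reference to the onloaded tensor before the alignment context exits.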
