@@ -2,7 +2,7 @@
 
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from compressed_tensors.utils import align_module_device, update_parameter_data
+from compressed_tensors.utils import align_modules, update_parameter_data
 from torch.nn import Linear, Module
 
 __all__ = ["update_fused_layer_weight_global_scales"]
@@ -51,17 +51,17 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                 return False
         return True
 
-    with align_module_device(submodule):
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                return
+    if _is_attention_module(submodule):
+        # already fused/treated as one layer
+        if hasattr(submodule, "qkv_proj"):
+            return
 
-            if not _valid_tensor_group_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                return
+        if not _valid_tensor_group_quant(
+            [submodule.q_proj, submodule.v_proj, submodule.k_proj]
+        ):
+            return
 
+        with align_modules([submodule.q_proj, submodule.v_proj, submodule.k_proj]):
             global_scale = torch.min(
                 torch.cat(
                     (
@@ -70,29 +70,31 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
                         submodule.v_proj.weight_global_scale.data,
                     )
                 )
-            )
+            ).reshape([1])
 
             update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
             update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
             update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
+            del global_scale
 
-    with align_module_device(submodule):
-        if _is_mlp_module(submodule):
-            if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
-                return
+    if _is_mlp_module(submodule):
+        if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]):
+            return
 
+        with align_modules([submodule.gate_proj, submodule.up_proj]):
             global_scale = torch.min(
                 torch.cat(
                     (
                         submodule.gate_proj.weight_global_scale.data,
                         submodule.up_proj.weight_global_scale.data,
                     )
                 )
-            )
+            ).reshape([1])
 
             update_parameter_data(
                 submodule.gate_proj, global_scale, "weight_global_scale"
             )
             update_parameter_data(
                 submodule.up_proj, global_scale, "weight_global_scale"
            )
+            del global_scale
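
For readers skimming the diff: the scale-fusion logic itself is unchanged in spirit. Each q/k/v (or gate/up) projection carries its own calibrated weight_global_scale, and the update collapses them into one shared scale by taking the minimum across layers; the new .reshape([1]) only normalizes the result from torch.min's 0-dim output to a 1-element tensor, so every layer stores the scale with a consistent shape. Below is a minimal, self-contained sketch of that reduction in plain torch; the ToyAttention module and its scale values are hypothetical stand-ins, not part of the actual codebase:

import torch
from torch.nn import Linear, Module


class ToyAttention(Module):
    """Hypothetical stand-in for an attention block with unfused q/k/v."""

    def __init__(self):
        super().__init__()
        self.q_proj = Linear(8, 8)
        self.k_proj = Linear(8, 8)
        self.v_proj = Linear(8, 8)
        # Per-layer global scales as a calibration pass might leave them
        # (values made up for illustration).
        self.q_proj.weight_global_scale = torch.tensor([3.0])
        self.k_proj.weight_global_scale = torch.tensor([2.0])
        self.v_proj.weight_global_scale = torch.tensor([5.0])


attn = ToyAttention()

# Same reduction as the diff: concatenate the per-layer scales, take the
# global minimum, then reshape the 0-dim result to a 1-element tensor.
global_scale = torch.min(
    torch.cat(
        (
            attn.q_proj.weight_global_scale.data,
            attn.k_proj.weight_global_scale.data,
            attn.v_proj.weight_global_scale.data,
        )
    )
).reshape([1])

print(global_scale, global_scale.shape)  # tensor([2.]) torch.Size([1])

Taking the minimum appears to be the conservative fusion choice: since a smaller global scale corresponds to a wider observed weight range, the smallest per-layer scale is the one that still safely covers all of the fused layers.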
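
The more substantive change is where device alignment happens. The old code entered align_module_device(submodule) before even checking whether a branch applied, onloading the entire parent block. The new code runs the cheap structural checks first and then onloads only the projection layers it actually reads, via align_modules, scoped tightly to the scale computation; the added del global_scale releases the temporary as soon as the per-layer parameters have been updated. A sketch of the narrowed pattern, reusing the hypothetical attn module from the sketch above and assuming align_modules accepts a list of modules, as it is called in this diff:

import torch
from compressed_tensors.utils import align_modules

# `attn` is the hypothetical ToyAttention instance from the previous sketch.
# Only the three projections are onloaded for the duration of this block;
# sibling layers of the parent module stay wherever they were offloaded.
with align_modules([attn.q_proj, attn.k_proj, attn.v_proj]):
    global_scale = torch.min(
        torch.cat(
            (
                attn.q_proj.weight_global_scale.data,
                attn.k_proj.weight_global_scale.data,
                attn.v_proj.weight_global_scale.data,
            )
        )
    ).reshape([1])
    # ... write the shared scale back to each layer here ...
    del global_scale  # as in the diff: free the temporary once written back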