@@ -48,6 +48,7 @@ def apply_spinquant(
     use_r2=False,
     use_r4=True,
     pretrained_rotation_path=None,
+    qkv_split=False,
 ):
     """
     Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406
@@ -57,9 +58,9 @@ def apply_spinquant(
     which appears to show best results in many cases (see https://github.com/pytorch/ao/pull/983).
 
     Note that the R3 rotation matrix and Cayley optimization for R1/R2 are currently not implemented.
-    """
-    assert isinstance(model, Transformer), "Only Transformer models are supported"
 
+    qkv_split should be set to True if the attention modules have separate wq, wk, wv tensors instead of a fused wqkv.
+    """
     original_device = next(model.parameters()).device
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device=device)
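A minimal, self-contained sketch (not part of the patch, and not the library's own helpers) of the rotation invariance the docstring relies on: folding an orthogonal R1 into adjacent weight matrices leaves the full-precision computation unchanged, which is why the new qkv_split flag only needs to route the same rotation to the right projection tensors.

```python
import torch

torch.manual_seed(0)
dim = 8
x = torch.randn(4, dim, dtype=torch.float64)     # activations feeding a linear layer
W = torch.randn(dim, dim, dtype=torch.float64)   # some linear layer weight (out x in)

# Any orthogonal matrix works for the argument; the library uses a random Hadamard matrix.
R1, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))

W_rot = W @ R1   # rotate the layer's input side: W' = W R1
x_rot = x @ R1   # the producer of x (embeddings / previous layer) is rotated the same way

# Full-precision outputs are identical: (x R1)(W R1)^T = x R1 R1^T W^T = x W^T
assert torch.allclose(x_rot @ W_rot.T, x @ W.T)
```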
@@ -75,18 +76,21 @@ def apply_spinquant(
         assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file."
 
     if use_r1:
-        fuse_layernorm_into_linear(model)
-        apply_spinquant_r1(model, device, pretrained_rotation_path)
+        fuse_layernorm_into_linear(model, qkv_split)
+        apply_spinquant_r1(model, device, pretrained_rotation_path, qkv_split)
     if use_r2:
-        apply_spinquant_r2(model, device, pretrained_rotation_path)
+        apply_spinquant_r2(model, device, pretrained_rotation_path, qkv_split)
     if use_r4:
         apply_spinquant_r4(model, device)
 
     model.to(device=original_device)
 
 
-def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
-    """Apply the SpinQuant R1 rotation matrix to the model."""
+def apply_spinquant_r1(model, device, pretrained_rotation_path=None, qkv_split=False):
+    """
+    Apply the SpinQuant R1 rotation matrix to the model.
+    qkv_split should be set to True if the attention modules have separate wq, wk, wv tensors instead of a fused wqkv.
+    """
 
     if pretrained_rotation_path is not None:
         R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64)
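For anyone wiring up pretrained_rotation_path: the code above expects a .bin file loadable with torch.load that contains at least an "R1" tensor of shape (model.config.dim, model.config.dim). A hypothetical sketch of producing such a file (key names other than "R1" are not visible in this diff, so check the full source before relying on them):

```python
import torch

dim = 512  # use model.config.dim of the target model in practice

# Any orthogonal matrix of the right size works as a stand-in here;
# SpinQuant's Cayley-optimized rotations would come from the upstream SpinQuant repo.
R1, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))

torch.save({"R1": R1}, "rotations.bin")  # the suffix must be .bin per the assert above
```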
@@ -97,11 +101,14 @@ def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
     else:
         R1 = random_hadamard_matrix(model.config.dim, device)
 
-    _rotate_model_r1(model, R1)
+    _rotate_model_r1(model, R1, qkv_split=qkv_split)
 
 
-def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
-    """Apply the SpinQuant R2 rotation matrices to the model."""
+def apply_spinquant_r2(model, device, pretrained_rotation_path=None, qkv_split=False):
+    """
+    Apply the SpinQuant R2 rotation matrices to the model.
+    qkv_split should be set to True if the attention modules have separate wq, wk, wv tensors instead of a fused wqkv.
+    """
 
     R2s = []  # note that unlike R1, there are multiple R2 matrices (one per layer)
     head_dim = model.config.head_dim
@@ -118,7 +125,7 @@ def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
            R2 = random_hadamard_matrix(head_dim, device)
        R2s.append(R2)
 
-    _rotate_model_r2(model, R2s)
+    _rotate_model_r2(model, R2s, qkv_split=qkv_split)
 
 
 def apply_spinquant_r4(model, device):
@@ -154,19 +161,19 @@ def _fuse_layernorm_into_linear(
 
 
 @torch.no_grad()
-def _rotate_model_r1(model, R1):
+def _rotate_model_r1(model, R1, qkv_split=False):
     _rotate_embeddings(model, R1)
     _rotate_head(model, R1)
 
     for layer in model.layers:
-        _rotate_attention_inputs(layer, R1)
+        _rotate_attention_inputs(layer, R1, qkv_split=qkv_split)
         _rotate_attention_output(layer, R1)
         _rotate_mlp_input(layer, R1)
         _rotate_mlp_output(layer, R1)
 
 
 @torch.no_grad()
-def _rotate_model_r2(model, R2s):
+def _rotate_model_r2(model, R2s, qkv_split=False):
     """Rotate the W_v and W_o weights of the multi-head self-attention modules."""
 
     head_dim = model.config.head_dim
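To see why _rotate_model_r2 only needs to touch W_v and W_o, here is a standalone check (with a generic per-head orthogonal matrix standing in for the random Hadamard R2): rotating W_v on its output side and W_o on its input side with the same per-head rotation cancels in full precision. Attention mixes value vectors head by head with scalar weights, so the cancellation survives the softmax mixing as well.

```python
import torch

torch.manual_seed(0)
dim, n_heads, head_dim = 16, 2, 8
x = torch.randn(3, dim, dtype=torch.float64)

Wv = torch.randn(n_heads * head_dim, dim, dtype=torch.float64)  # value projection (out x in)
Wo = torch.randn(dim, n_heads * head_dim, dtype=torch.float64)  # output projection
R2, _ = torch.linalg.qr(torch.randn(head_dim, head_dim, dtype=torch.float64))

# Per-head rotation: rotate W_v on its output side and W_o on its input side.
B = torch.block_diag(*([R2] * n_heads))  # block-diagonal per-head rotation
Wv_rot = B @ Wv                          # v' = B v
Wo_rot = Wo @ B.T                        # Wo' v' = Wo B^T B v = Wo v

out = (x @ Wv.T) @ Wo.T
out_rot = (x @ Wv_rot.T) @ Wo_rot.T
assert torch.allclose(out, out_rot)
```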
@@ -180,25 +187,28 @@ def _rotate_model_r2(model, R2s):
         # Rotate W_o
         apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2)
 
-        # Extract W_v
-        kv_size = model.config.n_local_heads * head_dim
-        wq, wk, wv = attn.wqkv.weight.data.split(
-            [model.config.dim, kv_size, kv_size], dim=0
-        )
-        out_features, in_features = wv.shape
-        wv_mod = nn.Linear(
-            in_features,
-            out_features,
-            bias=attn.wqkv.bias is not None,
-            device=wv.device,
-            dtype=wv.dtype,
-        )
-        wv_mod.weight.data = wv
+        if qkv_split:
+            apply_exact_had_to_linear(attn.wv, had_dim=head_dim, output=True, R2=R2)
+        else:
+            # Extract W_v
+            kv_size = model.config.n_local_heads * head_dim
+            wq, wk, wv = attn.wqkv.weight.data.split(
+                [model.config.dim, kv_size, kv_size], dim=0
+            )
+            out_features, in_features = wv.shape
+            wv_mod = nn.Linear(
+                in_features,
+                out_features,
+                bias=attn.wqkv.bias is not None,
+                device=wv.device,
+                dtype=wv.dtype,
+            )
+            wv_mod.weight.data = wv
 
-        # Rotate W_v
-        apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2)
+            # Rotate W_v
+            apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2)
 
-        attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)
+            attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)
 
 
 @torch.no_grad()
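The new qkv_split branch rotates attn.wv in place instead of splitting wqkv, rotating the W_v slice, and concatenating it back. A self-contained equivalence check of why the two code paths produce the same weights, using a generic per-head orthogonal rotation as a stand-in for what apply_exact_had_to_linear does with R2:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_heads, head_dim = 16, 2, 8
kv_size = n_heads * head_dim

# Fused QKV projection vs. the same weights kept as a separate wv module.
wqkv = nn.Linear(dim, dim + 2 * kv_size, bias=False, dtype=torch.float64)
wq_w, wk_w, wv_w = wqkv.weight.data.split([dim, kv_size, kv_size], dim=0)
wv = nn.Linear(dim, kv_size, bias=False, dtype=torch.float64)
wv.weight.data = wv_w.clone()

# Per-head output-side rotation (stand-in for apply_exact_had_to_linear(..., output=True, R2=R2)).
R2, _ = torch.linalg.qr(torch.randn(head_dim, head_dim, dtype=torch.float64))
B = torch.block_diag(*([R2] * n_heads))

# Fused path: rotate only the W_v slice, then concatenate back (what the else-branch does).
wv_rot_fused = B @ wv_w
wqkv.weight.data = torch.cat([wq_w, wk_w, wv_rot_fused], dim=0)

# Split path: rotate the standalone wv module directly (what the qkv_split branch does).
wv.weight.data = B @ wv.weight.data

x = torch.randn(3, dim, dtype=torch.float64)
assert torch.allclose(wqkv(x)[:, dim + kv_size:], wv(x))
```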
@@ -226,12 +236,14 @@ def _add_activation_wrappers_r4(model):
 
 
 @torch.no_grad()
-def fuse_layernorm_into_linear(model):
+def fuse_layernorm_into_linear(model, qkv_split=False):
     """
     Fuse RMSNorm weights into the subsequent linear layers.
 
     This is done in the paper specifically to make pre-norm LLMs like LLaMa
     rotation-invariant when quantization is not present.
+
+    qkv_split should be set to True if the attention modules have separate wq, wk, wv tensors instead of a fused wqkv.
     """
     # Embedding fusion (from SpinQuant repo: utils/fuse_norm_utils.py:43)
     # I currently don't understand why this is necessary, so I contacted the
@@ -244,7 +256,13 @@ def fuse_layernorm_into_linear(model):
         _fuse_layernorm_into_linear(
             layer.ffn_norm, [layer.feed_forward.w1, layer.feed_forward.w3]
         )
-        _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv])
+        if qkv_split:
+            _fuse_layernorm_into_linear(
+                layer.attention_norm,
+                [layer.attention.wq, layer.attention.wk, layer.attention.wv],
+            )
+        else:
+            _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv])
 
     _fuse_layernorm_into_linear(model.norm, [model.output])
 
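For reference, the algebra behind folding an RMSNorm gain into the following linear layers (a sketch of the idea, not necessarily the exact _fuse_layernorm_into_linear implementation): scaling each input column of the weight by the gain and resetting the gain to ones leaves the output unchanged, whether the norm feeds a fused wqkv or separate wq/wk/wv modules.

```python
import torch

torch.manual_seed(0)
dim, out = 8, 6
x = torch.randn(3, dim, dtype=torch.float64)
g = torch.randn(dim, dtype=torch.float64)       # RMSNorm gain
W = torch.randn(out, dim, dtype=torch.float64)  # subsequent linear layer weight

def rms_norm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# Fusion: fold the gain into the linear weight, reset the norm gain to ones.
W_fused = W * g                  # scales each input column of W by the gain
g_fused = torch.ones_like(g)

assert torch.allclose(rms_norm(x, g) @ W.T, rms_norm(x, g_fused) @ W_fused.T)
```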
@@ -270,8 +288,13 @@ def _rotate_attention_output(layer, R1):
         mod.bias.data = torch.matmul(R1.T, b).to(dtype=mod.weight.dtype)
 
 
-def _rotate_attention_inputs(layer, R1):
-    _rotate_mod_weight_right(layer.attention.wqkv, R1)
+def _rotate_attention_inputs(layer, R1, qkv_split=False):
+    if qkv_split:
+        _rotate_mod_weight_right(layer.attention.wq, R1)
+        _rotate_mod_weight_right(layer.attention.wk, R1)
+        _rotate_mod_weight_right(layer.attention.wv, R1)
+    else:
+        _rotate_mod_weight_right(layer.attention.wqkv, R1)
 
 
 def _rotate_head(model, R1):
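Finally, the split branch of _rotate_attention_inputs is equivalent to the fused one because R1 acts on the shared input side of the projections: rotating wq, wk, wv separately and re-fusing gives exactly the rotated wqkv. A small check with plain tensors (not the model modules):

```python
import torch

torch.manual_seed(0)
dim, kv_size = 16, 8
R1, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))

Wq = torch.randn(dim, dim, dtype=torch.float64)
Wk = torch.randn(kv_size, dim, dtype=torch.float64)
Wv = torch.randn(kv_size, dim, dtype=torch.float64)
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)  # fused projection weight

# Rotating the fused weight on the input side ...
fused_rot = Wqkv @ R1
# ... matches rotating each split weight and re-fusing.
split_rot = torch.cat([Wq @ R1, Wk @ R1, Wv @ R1], dim=0)

assert torch.allclose(fused_rot, split_rot)
```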