Commit 3addf30

SpinQuant support split qkv (#2547)

Summary: Extend SpinQuant to support models that have separate wq, wk, wv tensors in their attention modules (instead of a combined wqkv tensor).

Reviewed By: andrewor14

Differential Revision: D78280564

1 parent 975bd57 · commit 3addf30
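A minimal usage sketch of the new flag is below. It assumes a gpt-fast-style Transformer whose attention modules expose separate wq/wk/wv linear layers; `build_split_qkv_model()` is a hypothetical placeholder, not part of the repo.

```python
# Sketch only: `build_split_qkv_model()` is a hypothetical stand-in for loading a
# Transformer whose attention modules hold wq, wk, wv instead of a fused wqkv.
from torchao.prototype.spinquant.spinquant import apply_spinquant

model = build_split_qkv_model()

# Remaining arguments keep their defaults; the new flag tells SpinQuant which
# attention weight layout to rotate and fuse.
apply_spinquant(model, qkv_split=True)
```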

File tree

2 files changed: +60 -37 lines

torchao/prototype/spinquant/hadamard_utils.py

Lines changed: 2 additions & 2 deletions
@@ -175,7 +175,7 @@ def get_hadK(n, transpose=False):
         hadK = get_had12().T if transpose else get_had12()
     else:
         assert is_pow2(n)
-
+        hadK = torch.FloatTensor([[1]])
         K = 1

     return hadK, K
@@ -222,7 +222,7 @@ def matmul_hadU_fast(X, hadK, K):

 def random_hadamard_matrix(size, device, seed=0):
     # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
-    gen = torch.Generator()
+    gen = torch.Generator(device=device)
     gen.manual_seed(seed)
     Q = torch.randint(low=0, high=2, size=(size,), generator=gen).to(torch.float64)
     Q = Q * 2 - 1
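As an aside on the `torch.Generator(device=device)` change above, the snippet below is a generic illustration (not repo code) of pairing a device-bound, seeded generator with tensors created on that same device.

```python
import torch

# Illustration only: a Generator pinned to a device is used with tensors
# allocated on that same device; seeding it makes the draw reproducible there.
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(0)

# 0/1 draws mapped to +/-1, mirroring the sign-vector construction in
# random_hadamard_matrix (shape and dtype here are arbitrary).
q = torch.randint(low=0, high=2, size=(8,), generator=gen, device=device).to(torch.float64)
q = q * 2 - 1
```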

torchao/prototype/spinquant/spinquant.py

Lines changed: 58 additions & 35 deletions
@@ -48,6 +48,7 @@ def apply_spinquant(
     use_r2=False,
     use_r4=True,
     pretrained_rotation_path=None,
+    qkv_split=False,
 ):
     """
     Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406
@@ -57,9 +58,9 @@ def apply_spinquant(
     which appears to show best results in many cases (see https://github.com/pytorch/ao/pull/983).

     Note that the R3 rotation matrix and Cayley optimization for R1/R2 are currently not implemented.
-    """
-    assert isinstance(model, Transformer), "Only Transformer models are supported"

+    qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv
+    """
     original_device = next(model.parameters()).device
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device=device)
@@ -75,18 +76,21 @@ def apply_spinquant(
         assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file."

     if use_r1:
-        fuse_layernorm_into_linear(model)
-        apply_spinquant_r1(model, device, pretrained_rotation_path)
+        fuse_layernorm_into_linear(model, qkv_split)
+        apply_spinquant_r1(model, device, pretrained_rotation_path, qkv_split)
     if use_r2:
-        apply_spinquant_r2(model, device, pretrained_rotation_path)
+        apply_spinquant_r2(model, device, pretrained_rotation_path, qkv_split)
     if use_r4:
         apply_spinquant_r4(model, device)

     model.to(device=original_device)


-def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
-    """Apply the SpinQuant R1 rotation matrix to the model."""
+def apply_spinquant_r1(model, device, pretrained_rotation_path=None, qkv_split=False):
+    """
+    Apply the SpinQuant R1 rotation matrix to the model.
+    qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv
+    """

     if pretrained_rotation_path is not None:
         R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64)
@@ -97,11 +101,14 @@ def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
     else:
         R1 = random_hadamard_matrix(model.config.dim, device)

-    _rotate_model_r1(model, R1)
+    _rotate_model_r1(model, R1, qkv_split=qkv_split)


-def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
-    """Apply the SpinQuant R2 rotation matrices to the model."""
+def apply_spinquant_r2(model, device, pretrained_rotation_path=None, qkv_split=False):
+    """
+    Apply the SpinQuant R2 rotation matrices to the model.
+    qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv
+    """

     R2s = []  # note that unlike R1, there are multiple R2 matrices (one per layer)
     head_dim = model.config.head_dim
@@ -118,7 +125,7 @@ def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
             R2 = random_hadamard_matrix(head_dim, device)
         R2s.append(R2)

-    _rotate_model_r2(model, R2s)
+    _rotate_model_r2(model, R2s, qkv_split=qkv_split)


 def apply_spinquant_r4(model, device):
@@ -154,19 +161,19 @@ def _fuse_layernorm_into_linear(


 @torch.no_grad()
-def _rotate_model_r1(model, R1):
+def _rotate_model_r1(model, R1, qkv_split=False):
     _rotate_embeddings(model, R1)
     _rotate_head(model, R1)

     for layer in model.layers:
-        _rotate_attention_inputs(layer, R1)
+        _rotate_attention_inputs(layer, R1, qkv_split=qkv_split)
         _rotate_attention_output(layer, R1)
         _rotate_mlp_input(layer, R1)
         _rotate_mlp_output(layer, R1)


 @torch.no_grad()
-def _rotate_model_r2(model, R2s):
+def _rotate_model_r2(model, R2s, qkv_split=False):
     """Rotate the W_v and W_o weights of the multi-head self-attention modules."""

     head_dim = model.config.head_dim
@@ -180,25 +187,28 @@ def _rotate_model_r2(model, R2s):
         # Rotate W_o
         apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2)

-        # Extract W_v
-        kv_size = model.config.n_local_heads * head_dim
-        wq, wk, wv = attn.wqkv.weight.data.split(
-            [model.config.dim, kv_size, kv_size], dim=0
-        )
-        out_features, in_features = wv.shape
-        wv_mod = nn.Linear(
-            in_features,
-            out_features,
-            bias=attn.wqkv.bias is not None,
-            device=wv.device,
-            dtype=wv.dtype,
-        )
-        wv_mod.weight.data = wv
+        if qkv_split:
+            apply_exact_had_to_linear(attn.wv, had_dim=head_dim, output=True, R2=R2)
+        else:
+            # Extract W_v
+            kv_size = model.config.n_local_heads * head_dim
+            wq, wk, wv = attn.wqkv.weight.data.split(
+                [model.config.dim, kv_size, kv_size], dim=0
+            )
+            out_features, in_features = wv.shape
+            wv_mod = nn.Linear(
+                in_features,
+                out_features,
+                bias=attn.wqkv.bias is not None,
+                device=wv.device,
+                dtype=wv.dtype,
+            )
+            wv_mod.weight.data = wv

-        # Rotate W_v
-        apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2)
+            # Rotate W_v
+            apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2)

-        attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)
+            attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)


 @torch.no_grad()
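The non-split branch above recovers W_v from the fused projection by slicing along the output dimension. The toy snippet below (made-up shapes, not repo code) shows what that split produces for an nn.Linear-style [out_features, in_features] weight.

```python
import torch
import torch.nn as nn

# Toy shapes for illustration: dim=8, n_local_heads=2, head_dim=2 -> kv_size=4.
dim, kv_size = 8, 4
wqkv = nn.Linear(dim, dim + 2 * kv_size, bias=False)

# Splitting the fused [out_features, in_features] weight along dim=0 recovers the
# per-projection weights, matching the `attn.wqkv.weight.data.split(...)` call above.
wq, wk, wv = wqkv.weight.data.split([dim, kv_size, kv_size], dim=0)
assert wq.shape == (dim, dim)
assert wk.shape == (kv_size, dim) and wv.shape == (kv_size, dim)
```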
@@ -226,12 +236,14 @@ def _add_activation_wrappers_r4(model):


 @torch.no_grad()
-def fuse_layernorm_into_linear(model):
+def fuse_layernorm_into_linear(model, qkv_split=False):
     """
     Fuse RMSNorm weights into the subsequent linear layers.

     This is done in the paper specifically to make pre-norm LLMs like LLaMa
     rotation-invariant when quantization is not present.
+
+    qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv
     """
     # Embedding fusion (from SpinQuant repo: utils/fuse_norm_utils.py:43)
     # I currently don't understand why this is necessary, so I contacted the
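To make the fusion claim in the docstring concrete, here is a small numeric check (illustrative only, eps omitted) that folding an RMSNorm scale into the following linear layer leaves the output unchanged.

```python
import torch

# Illustration only: RMSNorm(x) * g followed by Linear W equals RMSNorm(x)
# followed by a Linear whose weight columns are pre-scaled by g.
torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.float64)   # activations
g = torch.randn(16, dtype=torch.float64)      # RMSNorm weight
W = torch.randn(32, 16, dtype=torch.float64)  # next Linear weight, [out, in]

normed = x / x.pow(2).mean(dim=-1, keepdim=True).sqrt()  # RMSNorm without eps

y_reference = (normed * g) @ W.T  # norm scale applied explicitly
y_fused = normed @ (W * g).T      # norm scale folded into the linear weight
assert torch.allclose(y_reference, y_fused)
```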
@@ -244,7 +256,13 @@ def fuse_layernorm_into_linear(model):
         _fuse_layernorm_into_linear(
             layer.ffn_norm, [layer.feed_forward.w1, layer.feed_forward.w3]
         )
-        _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv])
+        if qkv_split:
+            _fuse_layernorm_into_linear(
+                layer.attention_norm,
+                [layer.attention.wq, layer.attention.wk, layer.attention.wv],
+            )
+        else:
+            _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv])

     _fuse_layernorm_into_linear(model.norm, [model.output])
@@ -270,8 +288,13 @@ def _rotate_attention_output(layer, R1):
         mod.bias.data = torch.matmul(R1.T, b).to(dtype=mod.weight.dtype)


-def _rotate_attention_inputs(layer, R1):
-    _rotate_mod_weight_right(layer.attention.wqkv, R1)
+def _rotate_attention_inputs(layer, R1, qkv_split=False):
+    if qkv_split:
+        _rotate_mod_weight_right(layer.attention.wq, R1)
+        _rotate_mod_weight_right(layer.attention.wk, R1)
+        _rotate_mod_weight_right(layer.attention.wv, R1)
+    else:
+        _rotate_mod_weight_right(layer.attention.wqkv, R1)


 def _rotate_head(model, R1):
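The input-rotation step rests on an invariance of this form: if activations are multiplied on the right by an orthogonal matrix, multiplying each input-facing weight on the right by the same matrix preserves the layer output. A quick generic check of that identity (not repo code):

```python
import torch

# Generic check: with R orthogonal, (x @ R) @ (W @ R).T == x @ W.T,
# which is why right-multiplying input-facing weights is output-preserving.
torch.manual_seed(0)
dim = 16
R, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))  # random orthogonal matrix
x = torch.randn(4, dim, dtype=torch.float64)                        # incoming activations
W = torch.randn(32, dim, dtype=torch.float64)                       # Linear weight, [out, in]

y_reference = x @ W.T
y_rotated = (x @ R) @ (W @ R).T
assert torch.allclose(y_reference, y_rotated)
```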
