
Commit 2b00d04

[Transform] Norm fusing utilities (#1637)
## Purpose ##
* Provide utilities for fusing norms and embeddings for SpinQuantModifier

## Changes ##
* Implement `center_embeddings` and `fuse_norm_linears`
* `center_embeddings` doesn't seem to have an effect (and theoretically shouldn't have an effect, and makes the implementation less resilient), but is implemented by the SpinQuant paper. We can implement the utility here and decide to not use it later

## Testing ##
* Add `test_center_embeddings` and `test_fuse_norm_linears`

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 357437f commit 2b00d04
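The claim above that `center_embeddings` "theoretically shouldn't have an effect" holds when the embedding feeds a LayerNorm-style norm, which subtracts the per-token mean itself; subtracting a per-row constant from the embedding weights then cannot change the norm's output. A minimal sketch (not part of this commit, and only valid for mean-subtracting norms, not RMSNorm) checking that property:

```python
import torch

from llmcompressor.modeling.fuse import center_embeddings

# LayerNorm subtracts the per-token mean, so shifting every channel of an
# embedding row by that row's mean leaves the LayerNorm output unchanged
embedding = torch.nn.Embedding(10, 16)
norm = torch.nn.LayerNorm(16)
tokens = torch.arange(10)

before = norm(embedding(tokens))
center_embeddings(embedding)
after = norm(embedding(tokens))

# only tiny float64-to-float32 rounding differences remain
assert torch.allclose(before, after, atol=1e-6)
```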

File tree

2 files changed: +92 −0 lines changed


src/llmcompressor/modeling/fuse.py

Lines changed: 60 additions & 0 deletions
```python
from typing import Iterable

import torch
from compressed_tensors import (
    align_module_device,
    get_execution_device,
    update_offload_parameter,
)

__all__ = ["center_embeddings", "fuse_norm_linears"]


PRECISION = torch.float64


def center_embeddings(embedding: torch.nn.Module):
    """
    Shift each embedding to have a mean of zero

    :param embedding: embedding module containing embeddings to center
    """
    if not hasattr(embedding, "weight"):
        raise ValueError(f"Cannot fuse norm of type {type(embedding)}")

    with align_module_device(embedding):
        weight_dtype = embedding.weight.dtype
        weight = embedding.weight.to(PRECISION)
        new_weight = weight - weight.mean(dim=-1, keepdim=True)
        new_weight = new_weight.to(weight_dtype)

    update_offload_parameter(embedding, "weight", new_weight)


def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
    """
    Fuse the scaling operation of a norm layer into subsequent linear layers.
    This is useful for ensuring transform invariance between norm and linear layers.

    Note that unitary transforms (rotation) commute with normalization, but not scaling

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    """
    if not hasattr(norm, "weight"):
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")

    for linear in linears:
        # NOTE: spinquant does this op in float64
        exec_device = get_execution_device(norm)
        with align_module_device(norm, exec_device), align_module_device(
            linear, exec_device
        ):
            weight_dtype = linear.weight.dtype
            new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
            new_weight = new_weight.to(weight_dtype)

        update_offload_parameter(linear, "weight", new_weight)

    new_norm_weight = torch.ones_like(norm.weight, device="cpu")
    update_offload_parameter(norm, "weight", new_norm_weight)
```
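The docstring's note that rotations commute with normalization but not with the per-channel scale is exactly why the scale is fused forward into the linears. A small numeric sketch (not part of the diff) of that fact, using a simplified RMSNorm and a random orthogonal matrix:

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # simplified RMSNorm: divide by the row L2 norm (the 1/sqrt(d) factor of
    # true RMSNorm is a constant and does not affect commutation), then apply
    # the per-channel scale
    return weight * x / x.norm(dim=-1, keepdim=True)


hidden = 8
x = torch.randn(3, hidden, dtype=torch.float64)
scale = torch.rand(hidden, dtype=torch.float64) + 0.5  # non-trivial per-channel scale

# random orthogonal ("rotation") matrix from a QR decomposition
rotation, _ = torch.linalg.qr(torch.randn(hidden, hidden, dtype=torch.float64))

# with a per-channel scale, rotating before the norm differs from rotating after
assert not torch.allclose(rms_norm(x @ rotation, scale), rms_norm(x, scale) @ rotation)

# with the scale fused away (all-ones weight), rotation commutes with the norm
ones = torch.ones(hidden, dtype=torch.float64)
assert torch.allclose(rms_norm(x @ rotation, ones), rms_norm(x, ones) @ rotation)
```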
Lines changed: 32 additions & 0 deletions
```python
import pytest
import torch

from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears


@pytest.mark.unit
def test_center_embeddings():
    embedding = torch.nn.Embedding(10, 10)
    center_embeddings(embedding)

    assert torch.allclose(
        embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5
    )


@pytest.mark.unit
def test_fuse_norm_linears():
    norm = torch.nn.LayerNorm((5,))
    norm.weight.data = torch.rand(norm.weight.shape)
    linears = [
        torch.nn.Linear(5, 5),
        torch.nn.Linear(5, 5),
    ]

    input = torch.rand((1, 5), requires_grad=False)
    true_output = torch.stack([linear(norm(input)) for linear in linears])

    fuse_norm_linears(norm, linears)
    output = torch.stack([linear(norm(input)) for linear in linears])

    assert torch.allclose(true_output, output)
```
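For context, a hypothetical end-to-end sketch of how these utilities might be wired up for a Llama-style model before applying rotations. Attribute names follow Hugging Face's Llama implementation; this is not the SpinQuantModifier integration itself, which may differ:

```python
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears

# hypothetical example model; any Llama-architecture checkpoint would do
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

center_embeddings(model.model.embed_tokens)

for layer in model.model.layers:
    # the attention norm feeds the q/k/v projections
    fuse_norm_linears(
        layer.input_layernorm,
        (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj),
    )
    # the mlp norm feeds the gate/up projections
    fuse_norm_linears(
        layer.post_attention_layernorm,
        (layer.mlp.gate_proj, layer.mlp.up_proj),
    )

# the final norm feeds the LM head
fuse_norm_linears(model.model.norm, (model.lm_head,))
```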
