Commit 6ee452e

implement head dim
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 0f851d1 commit 6ee452e

File tree

6 files changed: +166 additions, −32 deletions

  src/compressed_tensors/transform/factory/hadamard.py
  src/compressed_tensors/transform/factory/matrix_multiply.py
  src/compressed_tensors/transform/transform_scheme.py
  src/compressed_tensors/transform/utils/matrix.py
  tests/test_transform/conftest.py
  tests/test_transform/factory/test_correctness.py

src/compressed_tensors/transform/factory/hadamard.py
Lines changed: 3 additions & 6 deletions

@@ -52,16 +52,15 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        num_heads = self.scheme.num_heads
-        size = get_matrix_size(module, args.location, num_heads)
+        size = get_matrix_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, num_heads)
+        return HadamardTransform(weight, perm, args)

     def _create_weight(
         self,
@@ -86,13 +85,11 @@ def __init__(
         weight: Parameter,
         perm: Optional[Parameter],
         args: TransformArgs,
-        num_heads: Optional[int],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
-        self.num_heads = num_heads

     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -103,4 +100,4 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T

-        return apply_transform_weight(weight, value, self.args.location, self.num_heads)
+        return apply_transform_weight(weight, value, self.args.location)
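A note on why the unchanged `weight.T` branch still serves as the inverse here: a Hadamard weight normalized by 1/sqrt(size) is orthogonal, so its transpose is its inverse. A minimal sketch in plain PyTorch (Sylvester construction; this is not the factory's own weight-creation code):

# Minimal sketch: a Hadamard matrix built by the Sylvester recursion and scaled by
# 1/sqrt(size) is orthogonal, which is why the transform can use weight.T as its inverse.
import torch

def sylvester_hadamard(size: int) -> torch.Tensor:
    # size is assumed to be a power of two for this construction
    h = torch.ones(1, 1)
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / size ** 0.5

weight = sylvester_hadamard(8)
assert torch.allclose(weight @ weight.T, torch.eye(8), atol=1e-6)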

src/compressed_tensors/transform/factory/matrix_multiply.py
Lines changed: 5 additions & 11 deletions

@@ -51,16 +51,15 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert isinstance(module, Linear)
-        num_heads = self.scheme.num_heads
-        size = get_matrix_size(module, args.location, num_heads)
+        size = get_matrix_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)

         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]

-        return RandomMatrixTransform(weight, args, num_heads)
+        return RandomMatrixTransform(weight, args)

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -75,22 +74,17 @@ def _create_inverse(self, weight: Parameter) -> Parameter:


 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs, num_heads: Optional[int]):
+    def __init__(self, weight: Tensor, args: TransformArgs):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
-        self.num_heads = num_heads

     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(
-            self.weight, value, self.args.location, self.num_heads
-        )
+        return apply_transform_weight(self.weight, value, self.args.location)

     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(
-            inverse, value, self.args.location, self.num_heads
-        )
+        return apply_transform_weight(inverse, value, self.args.location)


 def high_precision_invert(weight: Tensor) -> Tensor:
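The `forward` / `right_inverse` pair above should cancel each other. A minimal round-trip sketch in plain PyTorch, with the inverse taken in float64 in the spirit of `high_precision_invert` (the tensor sizes and stand-in weight are illustrative only):

# Round-trip sketch: apply a random square transform at an "input"-style location,
# then undo it with a high-precision inverse; the result should match the original.
import torch

torch.manual_seed(0)
weight = torch.randn(16, 16)          # stand-in for a random transform weight
value = torch.randn(4, 16)            # stand-in activations

transformed = value @ weight          # forward: multiply by the transform weight
inverse = torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
recovered = transformed @ inverse     # right_inverse: multiply by the inverse

assert torch.allclose(recovered, value, atol=1e-3)  # loose tolerance for float32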

src/compressed_tensors/transform/transform_scheme.py
Lines changed: 1 addition & 1 deletion

@@ -40,4 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    num_heads: Optional[int] = Field(default=None)
+    head_dim: Optional[int] = Field(default=None)
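For reference, the new `head_dim` field is exercised by the tests added in this commit. A hedged usage sketch, assuming these names are exported from `compressed_tensors.transform` as the test files import them; the `v_proj`/`o_proj` targets and `head_dim=128` are illustrative:

# Usage sketch: a scheme with head_dim set builds a head_dim x head_dim transform
# that is applied per head, instead of one full-width matrix per module.
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

config = TransformConfig(
    config_groups={
        "": TransformScheme(
            type="hadamard",
            head_dim=128,  # per-head transform size, replacing the old num_heads field
            apply=[
                TransformArgs(targets="v_proj", location="weight_output"),
                TransformArgs(targets="o_proj", location="weight_input", inverse=True),
            ],
        )
    }
)
# apply_transform_config(model, config)  # fuse into an existing torch model, as in the tests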

src/compressed_tensors/transform/utils/matrix.py
Lines changed: 26 additions & 14 deletions

@@ -24,38 +24,36 @@
 def get_matrix_size(
     module: torch.nn.Module,
     location: TransformLocation,
-    num_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
 ) -> int:
     """
     Determine the size of a matrix given its location on the module

     :param module: module that matrix will be applied to
     :param location: location on module
+    :param head_dim: size of the transform matrix when it is applied per attention head
     :return: size of matrix
     """
     assert isinstance(module, torch.nn.Linear)

-    if location in ("input", TransformLocation.WEIGHT_INPUT):
+    if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
         size = module.in_features
     else:
         size = module.out_features

-    if num_heads is not None:
-        assert size % num_heads == 0
-        size = size // num_heads
+    if head_dim is not None:
+        assert size % head_dim == 0
+        return head_dim

-    return size
+    else:
+        return size


 def apply_transform_weight(
     weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
-    num_heads: Optional[int] = None,
 ) -> torch.Tensor:
-    if num_heads is not None:
-        weight = weight.repeat((num_heads, num_heads))
-
     return apply_transform_weight_linear(weight, value, location)


@@ -99,17 +97,31 @@ def apply_transform_weight_linear(
     :param location: determines how weight should be applied
     :return: value after transform weight has been applied
     """
+    value_shape = value.shape
+    weight_size = weight.shape[0]
+    assert weight.shape[0] == weight.shape[1]
+
     if location == TransformLocation.INPUT:
-        return value @ weight
+        num_heads = value_shape[1] // weight_size
+        value = value.reshape(value_shape[0], num_heads, weight_size)
+        ret = value @ weight

     elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
+        num_heads = value_shape[1] // weight_size
+        value = value.reshape(value_shape[0], num_heads, weight_size)
+        ret = value @ weight.T

     elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
+        num_heads = value_shape[0] // weight_size
+        value = value.reshape(num_heads, weight_size, value_shape[1])
+        ret = weight.T @ value

     elif location == TransformLocation.OUTPUT:
-        return value @ weight
+        num_heads = value_shape[1] // weight_size
+        value = value.reshape(value_shape[0], num_heads, weight_size)
+        ret = value @ weight

     else:
         raise NotImplementedError(f"{location} has not been implemented yet")
+
+    return ret.reshape(value_shape)
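The reshape above is what makes a single `head_dim x head_dim` weight act per head. A minimal numerical sketch in plain PyTorch showing that the per-head matmul equals multiplying the flat value by a block-diagonal matrix that repeats the weight once per head (shapes are illustrative):

# Per-head application via reshape == block-diagonal application on the flat value.
import torch

torch.manual_seed(0)
batch, num_heads, head_dim = 3, 4, 8
weight = torch.randn(head_dim, head_dim)
value = torch.randn(batch, num_heads * head_dim)

# per-head application, as in the INPUT / OUTPUT branches above
per_head = (value.reshape(batch, num_heads, head_dim) @ weight).reshape(batch, -1)

# equivalent full-width application with a block-diagonal weight
flat = value @ torch.block_diag(*([weight] * num_heads))

assert torch.allclose(per_head, flat, atol=1e-5)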

tests/test_transform/conftest.py
Lines changed: 56 additions & 0 deletions

@@ -33,6 +33,62 @@ def forward(self, x):
         return x


+class MockAttention(torch.nn.Module):
+    def __init__(
+        self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
+        self.num_key_value_groups = num_attention_heads // num_key_value_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+
+        self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
+        self.k_proj = torch.nn.Linear(
+            hidden_size, num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = torch.nn.Linear(
+            hidden_size, num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, hidden_size = hidden_states.shape
+        hidden_shape = (batch_size, seq_len, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        key_states = self.repeat_kv(key_states, self.num_key_value_groups)
+        value_states = self.repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = (
+            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+        )
+
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape((batch_size, seq_len, -1)).contiguous()
+
+        return self.o_proj(attn_output)
+
+    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(
+            batch, num_key_value_heads, n_rep, slen, head_dim
+        )
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 @pytest.fixture(scope="function")
 def model_apply():
     model = TransformableModel(2, 4, 8, 16, 32, 64)
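A quick shape-check sketch for the fixture above, with illustrative sizes; it only verifies that `MockAttention` maps `(batch, seq_len, hidden_size)` back to the same shape, with grouped key/value heads expanded by `repeat_kv`:

# Smoke test for MockAttention (hypothetical sizes, not part of the commit).
import torch
from tests.test_transform.conftest import MockAttention

attention = MockAttention(hidden_size=64, num_attention_heads=8, num_key_value_heads=2)
hidden_states = torch.rand(2, 5, 64)
assert attention(hidden_states).shape == (2, 5, 64)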

tests/test_transform/factory/test_correctness.py
Lines changed: 75 additions & 0 deletions

@@ -22,6 +22,7 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
+from tests.test_transform.conftest import MockAttention
 from tests.testing_utils import requires_accelerate, requires_gpu


@@ -87,3 +88,77 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
 @pytest.mark.parametrize("randomized", (True, False))
 def test_correctness_model_offload(type, randomized, model_apply):
     test_correctness_model(type, randomized, model_apply, offload=True)
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("head_dim", (16, 32))
+def test_correctness_heads(type, randomized, head_dim, offload=False):
+    hidden_size = 64
+
+    model = torch.nn.ModuleDict(
+        {
+            "v_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
+            "o_proj": torch.nn.Linear(hidden_size, hidden_size, bias=False),
+        }
+    )
+
+    input = torch.rand(17, 5, hidden_size)
+    true_output = model.o_proj(model.v_proj(input))
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomized=randomized,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="v_proj", location="weight_output"),
+                    TransformArgs(
+                        targets="o_proj", location="weight_input", inverse=True
+                    ),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
+
+    output = model.o_proj(model.v_proj(input))
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("head_dim", (8,))  # (8, 16))
+def test_correctness_attention_heads(type, randomized, head_dim, offload=False):
+    hidden_size = 4096
+    num_attention_heads = 32
+
+    attention = MockAttention(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=head_dim,
+    )
+
+    input = torch.rand(17, 5, hidden_size)
+    true_output = attention(input)
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomized=randomized,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="v_proj", location="weight_output"),
+                    TransformArgs(
+                        targets="o_proj", location="weight_input", inverse=True
+                    ),
+                ],
+            )
+        }
+    )
+    apply_transform_config(attention, config)
+
+    output = attention(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
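The invariance these tests rely on can be checked directly: fusing a per-head transform into `v_proj`'s output weight and its inverse into `o_proj`'s input weight leaves `o_proj(v_proj(x))` unchanged. A minimal sketch in plain PyTorch (float64, illustrative sizes; the fusion directions follow the `weight_output`/`weight_input` branches of `apply_transform_weight_linear` above):

# Fused v_proj/o_proj invariance with a per-head (block-diagonal) transform.
import torch

torch.manual_seed(0)
hidden_size, head_dim = 64, 16
num_heads = hidden_size // head_dim

x = torch.randn(3, hidden_size, dtype=torch.float64)
w_v = torch.randn(hidden_size, hidden_size, dtype=torch.float64)  # v_proj weight
w_o = torch.randn(hidden_size, hidden_size, dtype=torch.float64)  # o_proj weight

# the same head_dim x head_dim transform for every head == one block-diagonal matrix
t = torch.block_diag(*([torch.randn(head_dim, head_dim, dtype=torch.float64)] * num_heads))
t_inv = torch.linalg.inv(t)

baseline = (x @ w_v.T) @ w_o.T     # no transforms applied
fused_v = t.T @ w_v                # weight_output fusion: x @ fused_v.T == (x @ w_v.T) @ t
fused_o = w_o @ t_inv.T            # inverse weight_input fusion: z @ fused_o.T == z @ t_inv @ w_o.T
transformed = (x @ fused_v.T) @ fused_o.T

assert torch.allclose(baseline, transformed, atol=1e-6)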
