
Commit da36ca6

wip concept
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 52e7074 commit da36ca6

File tree

1 file changed: 75 additions, 0 deletions
@@ -0,0 +1,75 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformBase, TransformLocation
from compressed_tensors.utils import patch_attr
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


"""
Attention interfaces are functions with the following signature:
    module, query, key, value, attention_mask, scaling, dropout, **kwargs
They are retrieved via `from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS`

Idea: Yield a custom attention function which injects the module's transform
submodules, applying them to the query and key tensors before dispatching to the
original attention implementation.

Pros: relatively simple
Cons: ordering is hard, since submodules aren't ordered; it also gets a little harder
if you want to do things like attention output hooks.
We can just disable multiple attention transforms for now.
"""

# capture the original lookup before it is patched, so hooked wrappers can still
# dispatch to the real attention implementations
original_get_item = ALL_ATTENTION_FUNCTIONS.__getitem__


def make_hooked_attention(attention_name):
    def hooked_attention(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        scaling: float,
        dropout: float = 0.0,
        **kwargs,
    ):
        # apply any attention transforms attached to this attention module
        for submodule in module.children():
            if isinstance(submodule, TransformBase):
                if submodule.args.location == TransformLocation.Q_ATTN:
                    query = submodule(query)

                if submodule.args.location == TransformLocation.K_CACHE:
                    key = submodule(key)

        # dispatch to the original implementation registered under `attention_name`
        return original_get_item(attention_name)(
            module, query, key, value, attention_mask, scaling, dropout, **kwargs
        )

    return hooked_attention


# memoize one hooked wrapper per attention implementation name
_cache = {}


def patched_get_item(self, key):
    if key not in _cache:
        _cache[key] = make_hooked_attention(key)

    return _cache[key]


# patch lookups on ALL_ATTENTION_FUNCTIONS so they return the hooked wrappers
patch_attr(ALL_ATTENTION_FUNCTIONS, "__getitem__", patched_get_item)
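
A minimal sketch of how the patched lookup could be exercised. This is illustrative only: the "sdpa" key, the tensor shapes, the variable names, and the bare torch.nn.Module standing in for a model's attention module are all assumptions, and it presumes the patched __getitem__ is the path through which modeling code resolves its attention implementation; with TransformBase children attached to the attention module, the hooked wrapper would apply them to query and key before dispatching.

import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# resolve an attention implementation by name, as modeling code does; with the
# patch above in effect, this lookup is meant to return the hooked wrapper
attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]

# illustrative shapes: (batch, num_heads, seq_len, head_dim)
batch, num_heads, seq_len, head_dim = 1, 4, 8, 16
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# a bare module stands in for a model's attention module; any TransformBase
# children attached to it would be applied to query/key inside the hooked wrapper
attn_module = torch.nn.Module()

attn_output, attn_weights = attention_interface(
    attn_module, query, key, value, None, scaling=head_dim**-0.5, dropout=0.0
)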

0 commit comments
