Commit bb50359

feat(modules): add linformer self attention
1 parent 8dcfac0 commit bb50359

1 file changed: +94 -0 lines changed
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention_ops import compute_mha
from .base_attention import BaseSelfAttention

__all__ = ["LinformerAttention"]


class LinformerAttention(BaseSelfAttention):
    def __init__(
        self,
        seq_len: int,
        head_dim: int,
        num_heads: int,
        k: int = None,
        how: str = "basic",
        slice_size: int = None,
        **kwargs
    ) -> None:
"""Linformer attention mechanism.
23+
24+
Linformer: Self-Attention with Linear Complexity
25+
- https://arxiv.org/abs/2006.04768v2
26+
27+
Adapted from xformers library
28+
29+
NOTE: Weirdly, even when computing linformer attention with xformers
30+
`memory_efficient_attention`, linformer needs more memory for long sequences
31+
(due to the linear layers) than computing exact `memory_efficient_attention`.
32+
33+
Parameters
34+
----------
35+
seq_len : int
36+
The length of the sequence. (For per-pixel patches H*W).
37+
head_dim : int
38+
Out dim per attention head.
39+
num_heads : int
40+
Number of heads.
41+
k : int, optional
42+
Divisor for key and value matrices to get low-rank attention matrix.
43+
how : str, default="basic"
44+
How to compute the self-attention matrix.
45+
One of ("basic", "flash", "slice", "memeff", "slice_memeff").
46+
"basic": the normal O(N^2) self attention.
47+
"flash": the flash attention (by xformers library),
48+
"slice": batch sliced attention operation to save mem.
49+
"memeff": xformers.memory_efficient_attention.
50+
"slice_memeff": Conmbine slicing and memory_efficient_attention.
51+
slice_size, int, optional
52+
The size of the slice. Used only if `how in ('slice', 'slice_memeff)`.
53+
"""
        super().__init__(
            head_dim=head_dim,
            num_heads=num_heads,
            how=how,
            slice_size=slice_size,
        )

        if k is None:
            k = seq_len // 4

        self.k = k
        # Low-rank projections over the sequence (length) dimension: seq_len -> k.
        self.E = nn.Linear(seq_len, k, bias=False)
        self.F = nn.Linear(seq_len, k, bias=False)
        self.seq_len = seq_len

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Forward pass of the linformer attention mechanism."""
        # Pad the sequence dimension up to `self.seq_len` so the fixed-size
        # projection layers can be applied.
        padding = 0
        if query.shape[1] < self.seq_len:
            padding = self.seq_len - query.shape[1]
            pad_dims = (0, 0, 0, padding)
            query = F.pad(query, pad_dims)
            key = F.pad(key, pad_dims)
            value = F.pad(value, pad_dims)

        # Project keys and values along the sequence dimension: (B, N, C) -> (B, k, C).
        key_proj = self.E(key.transpose(-1, -2)).transpose(-1, -2).contiguous()
        value_proj = self.F(value.transpose(-1, -2)).transpose(-1, -2).contiguous()

        out = compute_mha(
            query,
            key_proj,
            value_proj,
            self.how,
            slice_size=self.slice_size,  # used only for slice-att
            num_heads=self.num_heads,  # used only for slice-att
            proj_channels=self.proj_channels,  # used only for slice-att
        )

        # Strip any padding that was added to the sequence dimension.
        return out[:, :-padding, :] if padding > 0 else out
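For context, a minimal usage sketch (not part of this commit). It assumes `BaseSelfAttention` is an `nn.Module`, a `(batch, sequence, channels)` input layout (inferred from the padding and transpose logic in `forward`), and hypothetical configuration values; `compute_mha` is assumed to return the same layout.

    import torch

    # Hypothetical configuration; k defaults to 1024 // 4 = 256 low-rank tokens.
    attn = LinformerAttention(seq_len=1024, head_dim=64, num_heads=8)

    # (batch, sequence, channels); sequences shorter than seq_len are padded
    # internally and the padding is stripped from the output.
    x = torch.randn(2, 900, 64 * 8)
    out = attn(x, x, x)  # self-attention: query = key = value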
