Commit 2bcc590

add ability to set flash = True on RIN
1 parent a90bc8e commit 2bcc590

4 files changed (+156, -7 lines)


README.md

Lines changed: 9 additions & 0 deletions
@@ -149,3 +149,12 @@ sampled_images.shape # (4, 3, 128, 128)
     year = {2023}
 }
 ```
+
+```bibtex
+@inproceedings{dao2022flashattention,
+    title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year = {2022}
+}
+```

rin_pytorch/attend.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        l2_dist = False
    ):
        super().__init__()
        assert not (flash and l2_dist), 'flash attention is not compatible with l2 distance'
        self.l2_dist = l2_dist

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # check if mask exists and expand to compatible shape
        # the mask is B L, so it would have to be expanded to B H N L

        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # l2 distance

        if self.l2_dist:
            # -cdist squared == (-q^2 + 2qk - k^2)
            # so simply work off the qk above
            q_squared = reduce(q ** 2, 'b h i d -> b h i 1', 'sum')
            k_squared = reduce(k ** 2, 'b h j d -> b h 1 j', 'sum')
            sim = sim * 2 - q_squared - k_squared

        # key padding mask

        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
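
The new `Attend` module can also be exercised on its own. A minimal sketch, using illustrative tensor shapes in the usual `(batch, heads, sequence length, head dimension)` layout (the shapes below are made up for demonstration; the import path matches the one added to `rin_pytorch.py` in this commit):

```python
import torch
from rin_pytorch.attend import Attend

# illustrative shapes: (batch, heads, sequence length, head dimension)
q = torch.randn(2, 8, 64, 32)
k = torch.randn(2, 8, 64, 32)
v = torch.randn(2, 8, 64, 32)

# flash = True requires pytorch >= 2.0 and routes through F.scaled_dot_product_attention;
# flash = False takes the explicit einsum similarity -> softmax -> aggregate path
attend = Attend(flash = False)

out = attend(q, k, v)                        # (2, 8, 64, 32)

# optional key padding mask of shape (batch, key length)
mask = torch.ones(2, 64, dtype = torch.bool)
out = attend(q, k, v, mask = mask)           # (2, 8, 64, 32)
```

On CUDA with `flash = True`, the constructor's device check selects the flash kernel on A100s and falls back to the math / memory-efficient kernels on other GPUs.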

rin_pytorch/rin_pytorch.py

Lines changed: 7 additions & 6 deletions
@@ -18,6 +18,8 @@
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange
 
+from rin_pytorch.attend import Attend
+
 from PIL import Image
 from tqdm.auto import tqdm
 from ema_pytorch import EMA
@@ -166,7 +168,8 @@ def __init__(
         dim_head = 32,
         norm = False,
         norm_context = False,
-        time_cond_dim = None
+        time_cond_dim = None,
+        flash = False
     ):
         super().__init__()
         hidden_dim = dim_head * heads
@@ -194,6 +197,8 @@ def __init__(
         self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
         self.to_out = nn.Linear(hidden_dim, dim, bias = False)
 
+        self.attend = Attend(flash = flash)
+
     def forward(
         self,
         x,
@@ -217,12 +222,8 @@ def forward(
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-        q = q * self.scale
-
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
-        attn = sim.softmax(dim = -1)
+        out = self.attend(q, k, v)
 
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
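
The forward-pass change above is meant to be behavior-preserving when `flash = False`: `Attend` applies the same `dim_head ** -0.5` scaling internally that the removed `q = q * self.scale` line applied, and PyTorch's `scaled_dot_product_attention` uses that same default scale on the flash path. A quick equivalence check one could run against the old einsum path (tensor shapes here are illustrative):

```python
import torch
from torch import einsum
from rin_pytorch.attend import Attend

# illustrative shapes: (batch, heads, sequence length, head dimension)
q, k, v = (torch.randn(2, 8, 64, 32) for _ in range(3))

# old path: manual scaling, einsum similarity, softmax, aggregate values
scale = q.shape[-1] ** -0.5
sim = einsum('b h i d, b h j d -> b h i j', q * scale, k)
attn = sim.softmax(dim = -1)
out_old = einsum('b h i j, b h j d -> b h i d', attn, v)

# new path: delegate to Attend (non-flash branch, zero dropout)
out_new = Attend(flash = False)(q, k, v)

assert torch.allclose(out_old, out_new, atol = 1e-5)
```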

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.4',
+  version = '0.7.5',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
