1
1
import torch
2
- from torch .autograd import grad
3
2
from torch .nn .attention .flex_attention import flex_attention , create_block_mask
4
3
import pytest
5
- from functools import partial
6
4
from attn_gym .masks import generate_natten , generate_tiled_natten , generate_morton_natten
7
5
from attn_gym .masks .natten import morton_decode , morton_encode
8
6
9
7
10
-
11
8
def run_natten(
    mask=None,
    encoder=None,
    decoder=None,
    query=None,
    key=None,
    value=None,
    gradOut=None,
    print_mask=True,
):
    """Run flex_attention with a NATTEN block mask and return output + grads.

    Args:
        mask: mask_mod callable passed to ``create_block_mask``.
        encoder: optional (x, y) -> flat-index map used to un-permute the
            flattened outputs back to row-major order.
        decoder: optional flat-index -> (x, y) map; when given, q/k/v are
            gathered in the decoder's permuted order before attention.
        query, key, value: tensors of shape (B, H, W, W, D).
        gradOut: upstream gradient, same shape as ``query``.
        print_mask: if True, print the constructed block mask.

    Returns:
        list ``[out, q_grad, k_grad, v_grad]``, each of shape (B, H, W, W, D).
    """
    B, H, W, _, D = query.shape

    if decoder:
        # Gather the (W, W) grid in the permuted order given by the decoder,
        # flattening the two spatial dims into one sequence dim of length W*W.
        permuter_x, permuter_y = decoder(torch.arange(W * W))

        def _permute(t):
            # Advanced indexing on the two spatial dims -> (B, H, W*W, D).
            return t[:, :, permuter_x, permuter_y, :]

        # clone().detach() gives fresh leaves so .grad accumulates here,
        # independent of the original (possibly requires_grad) inputs.
        q = _permute(query).clone().detach().requires_grad_(query.requires_grad)
        k = _permute(key).clone().detach().requires_grad_(key.requires_grad)
        v = _permute(value).clone().detach().requires_grad_(value.requires_grad)
        dO = _permute(gradOut)
    else:
        # No permutation: plain row-major flatten of the spatial dims.
        q = query.flatten(2, 3).clone().detach().requires_grad_(query.requires_grad)
        k = key.flatten(2, 3).clone().detach().requires_grad_(key.requires_grad)
        v = value.flatten(2, 3).clone().detach().requires_grad_(value.requires_grad)
        dO = gradOut.flatten(2, 3)

    block_mask = create_block_mask(mask, 1, 1, W * W, W * W, device=query.device)
    if print_mask:
        print(f"\nBlock Mask:\n{block_mask}")

    flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
    out = flex_attention_compiled(q, k, v, block_mask=block_mask)

    out.backward(dO)

    if encoder:
        # Build the inverse permutation: encoder maps row-major (x, y)
        # coordinates back to positions in the permuted sequence.
        i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten()
        i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten()
        depermuter = encoder(i_x, i_y)

        out = out[:, :, depermuter, :].reshape(B, H, W, W, D)
        q_grad = q.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
        k_grad = k.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
        v_grad = v.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
    else:
        out = out.reshape(B, H, W, W, D)
        q_grad = q.grad.reshape(B, H, W, W, D)
        k_grad = k.grad.reshape(B, H, W, W, D)
        v_grad = v.grad.reshape(B, H, W, W, D)

    results = [out, q_grad, k_grad, v_grad]

    # Free the flattened working tensors before returning the reshaped views.
    del q, k, v, dO

    return results
62
68
63
69
@@ -69,25 +75,21 @@ def test_natten_masks(
69
75
K_W = 13 ,
70
76
T_W = 8 ,
71
77
print_mask = True ,
72
- ):
73
- query = torch .randn (
74
- B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True
75
- )
76
- key = torch .randn (
77
- B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True
78
- )
79
- value = torch .randn (
80
- B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True
81
- )
78
+ ):
79
+ query = torch .randn (B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True )
80
+ key = torch .randn (B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True )
81
+ value = torch .randn (B , H , W , W , D , device = "cuda" , dtype = torch .float16 , requires_grad = True )
82
82
gradOut = torch .randn (B , H , W , W , D , device = "cuda" , dtype = torch .float16 )
83
-
84
-
83
+
85
84
# Run naive NA
86
85
naive_mask = generate_natten (W , W , K_W , K_W )
87
- naive_results = run_natten (mask = naive_mask , query = query , key = key , value = value , gradOut = gradOut , print_mask = print_mask )
88
-
86
+ naive_results = run_natten (
87
+ mask = naive_mask , query = query , key = key , value = value , gradOut = gradOut , print_mask = print_mask
88
+ )
89
+
89
90
# Run tiled NA
90
91
T_H = T_W
92
+
91
93
def tiled_encoder (x , y ):
92
94
"""
93
95
Map 2-D coordinates to 1-D index for static tiles of T_H x T_W.
@@ -106,14 +108,33 @@ def tiled_decoder(idx):
106
108
t_x , t_y = t_id // (W // T_W ), t_id % (W // T_W )
107
109
t_offset = idx % (T_H * T_W )
108
110
i_x , i_y = t_offset // T_W , t_offset % T_W
109
- return t_x * T_W + i_x , t_y * T_H + i_y
111
+ return t_x * T_W + i_x , t_y * T_H + i_y
112
+
110
113
tiled_mask = generate_tiled_natten (W , W , K_W , K_W , T_W , T_H )
111
- tiled_results = run_natten (mask = tiled_mask , encoder = tiled_encoder , decoder = tiled_decoder , query = query , key = key , value = value , gradOut = gradOut , print_mask = print_mask )
112
-
114
+ tiled_results = run_natten (
115
+ mask = tiled_mask ,
116
+ encoder = tiled_encoder ,
117
+ decoder = tiled_decoder ,
118
+ query = query ,
119
+ key = key ,
120
+ value = value ,
121
+ gradOut = gradOut ,
122
+ print_mask = print_mask ,
123
+ )
124
+
113
125
# Run morton NA
114
126
morton_mask = generate_morton_natten (W , W , K_W , K_W )
115
- morton_results = run_natten (mask = morton_mask , encoder = morton_encode , decoder = morton_decode , query = query , key = key , value = value , gradOut = gradOut , print_mask = print_mask )
116
-
127
+ morton_results = run_natten (
128
+ mask = morton_mask ,
129
+ encoder = morton_encode ,
130
+ decoder = morton_decode ,
131
+ query = query ,
132
+ key = key ,
133
+ value = value ,
134
+ gradOut = gradOut ,
135
+ print_mask = print_mask ,
136
+ )
137
+
117
138
for naive , tiled , morton in zip (naive_results , tiled_results , morton_results ):
118
139
torch .testing .assert_close (naive , tiled , atol = 1e-1 , rtol = 1e-2 )
119
140
print ("Tiled NATTEN: Correctness check passed ✅" )
@@ -124,5 +145,6 @@ def tiled_decoder(idx):
124
145
del query , key , value , gradOut , naive_results , tiled_results
125
146
torch .cuda .empty_cache ()
126
147
148
+
127
149
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python thisfile.py`).
    pytest.main([__file__])
0 commit comments