Commit dabfa8b

Fix Ruff (#90)
stack-info: PR: #90, branch: drisspg/stack/3
1 parent b32a6b5 commit dabfa8b

File tree

3 files changed: +79 additions, -52 deletions

attn_gym/masks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
 from attn_gym.masks.document_mask import generate_doc_mask_mod
 from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
-from attn_gym.masks.natten import generate_natten, generate_tiled_natten, generate_morton_natten
+from attn_gym.masks.natten import generate_natten, generate_tiled_natten, generate_morton_natten

attn_gym/masks/natten.py

Lines changed: 11 additions & 6 deletions
@@ -42,6 +42,7 @@ def natten_mask_mod(
     natten_mask_mod.__name__ = f"natten_c{canvas_w}x{canvas_h}_k{kernel_w}x{kernel_h}"
     return natten_mask_mod
 
+
 def generate_tiled_natten(
     W: int,
     H: int,
@@ -68,7 +69,7 @@ def get_x_y_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
         t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)
         t_offset = idx % (T_H * T_W)
         i_x, i_y = t_offset // T_W, t_offset % T_W
-        return t_x*T_W + i_x, t_y*T_H + i_y
+        return t_x * T_W + i_x, t_y * T_H + i_y
 
     def tiled_natten_mask(
         b: IntTensor,
@@ -87,6 +88,7 @@ def tiled_natten_mask(
     tiled_natten_mask.__name__ = f"tiled_natten_c{W}x{H}_k{K_W}x{K_H}_t{T_W}x{T_H}"
     return tiled_natten_mask
 
+
 def interleave_bits_32(x):
     """
     Interleave the bits of a 16-bit integer x, producing a 32-bit integer
@@ -99,6 +101,7 @@ def interleave_bits_32(x):
     x = (x | (x << 1)) & 0x55555555
     return x
 
+
 def morton_encode(x, y):
     """
     Encode 2D coordinates (x, y) into a Morton code (Z-order curve index).
@@ -112,6 +115,7 @@ def morton_encode(x, y):
     """
     return (interleave_bits_32(y) << 1) | interleave_bits_32(x)
 
+
 def deinterleave_bits_32(code):
     """
     Deinterleave bits to retrieve the original 16-bit integer.
@@ -123,6 +127,7 @@ def deinterleave_bits_32(code):
     code = (code | (code >> 8)) & 0x0000FFFF
     return code
 
+
 def morton_decode(code):
     """
     Decode a Morton code to retrieve the original 2D coordinates (x, y).
@@ -144,13 +149,14 @@ def generate_morton_natten(
     kernel_w: int,
     kernel_h: int,
 ) -> _mask_mod_signature:
-    """Generates a NATTEN attention mask with a given kernel size under morton curve layout.
+    """Generates a NATTEN attention mask with a given kernel size under morton curve layout.
     Args:
         canvas_w: The width of the canvas.
         canvas_h: The height of the canvas.
         kernel_w: The width of the kernel.
         kernel_h: The height of the kernel.
     """
+
     def natten_mask_mod(
         b: IntTensor,
         h: IntTensor,
@@ -170,6 +176,7 @@ def natten_mask_mod(
     natten_mask_mod.__name__ = f"morton_natten_c{canvas_w}x{canvas_h}_k{kernel_w}x{kernel_h}"
     return natten_mask_mod
 
+
 def main(device: str = "cpu"):
     """Visualize the attention scores of NATTEN mask mod.
     Note: a more complete implementation of NATTEN would include support for kernel dilation.
@@ -204,8 +211,7 @@ def make_tensor():
         device=device,
         name=natten_mask.__name__,
     )
-
-
+
     tiled_natten_mask = generate_tiled_natten(
         W=CANVAS_WIDTH,
         H=CANVAS_HEIGHT,
@@ -222,8 +228,7 @@ def make_tensor():
         device=device,
         name=tiled_natten_mask.__name__,
     )
-
-
+
     morton_natten_mask = generate_morton_natten(
         canvas_w=CANVAS_WIDTH,
         canvas_h=CANVAS_HEIGHT,
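
Aside (not part of the commit): the hunks above only show the tail end of the bit-interleaving helpers, so the snippet below is a minimal standalone sketch of the Morton (Z-order) round trip, calling morton_encode and morton_decode exactly as the test file below imports them. The canvas width W = 8 and the assumption that encoding inverts decoding are illustrative, not taken from the diff.

# Minimal sketch, not from this commit: round-trip indices through the Morton
# helpers defined in attn_gym/masks/natten.py and imported by the test below.
import torch

from attn_gym.masks.natten import morton_decode, morton_encode

W = 8  # assumed canvas width; a power of two keeps the Z-order curve dense

# Walk the Z-order curve: each curve index maps to an (x, y) canvas coordinate,
codes = torch.arange(W * W)
xs, ys = morton_decode(codes)

# and re-encoding those coordinates should reproduce the original indices.
assert torch.equal(morton_encode(xs, ys), codes)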

test/test_natten.py

Lines changed: 67 additions & 45 deletions
@@ -1,63 +1,69 @@
 import torch
-from torch.autograd import grad
 from torch.nn.attention.flex_attention import flex_attention, create_block_mask
 import pytest
-from functools import partial
 from attn_gym.masks import generate_natten, generate_tiled_natten, generate_morton_natten
 from attn_gym.masks.natten import morton_decode, morton_encode
 
 
-
 def run_natten(
-    mask = None,
-    encoder = None,
-    decoder = None,
-    query = None,
-    key = None,
-    value = None,
-    gradOut = None,
+    mask=None,
+    encoder=None,
+    decoder=None,
+    query=None,
+    key=None,
+    value=None,
+    gradOut=None,
     print_mask=True,
 ):
     B, H, W, _, D = query.shape
     if decoder:
-        permuter_x, permuter_y = decoder(torch.arange(W*W))
-        permuter_index = permuter_x * W + permuter_y
-        q = query[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(query.requires_grad)
+        permuter_x, permuter_y = decoder(torch.arange(W * W))
+        q = (
+            query[:, :, permuter_x, permuter_y, :]
+            .clone()
+            .detach()
+            .requires_grad_(query.requires_grad)
+        )
         k = key[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(key.requires_grad)
-        v = value[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(value.requires_grad)
+        v = (
+            value[:, :, permuter_x, permuter_y, :]
+            .clone()
+            .detach()
+            .requires_grad_(value.requires_grad)
+        )
         dO = gradOut[:, :, permuter_x, permuter_y, :]
-    else:
+    else:
         q = query.flatten(2, 3).clone().detach().requires_grad_(query.requires_grad)
         k = key.flatten(2, 3).clone().detach().requires_grad_(key.requires_grad)
         v = value.flatten(2, 3).clone().detach().requires_grad_(value.requires_grad)
         dO = gradOut.flatten(2, 3)
-    block_mask = create_block_mask(mask, 1, 1, W*W, W*W, device=query.device)
+    block_mask = create_block_mask(mask, 1, 1, W * W, W * W, device=query.device)
     if print_mask:
         print(f"\nBlock Mask:\n{block_mask}")
-
+
     flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
     out = flex_attention_compiled(q, k, v, block_mask=block_mask)
-
+
     out.backward(dO)
-
-    if encoder:
-        i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten()
-        i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten()
+
+    if encoder:
+        i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten()
+        i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten()
         depermuter = encoder(i_x, i_y)
         out = out[:, :, depermuter, :].reshape(B, H, W, W, D)
         q_grad = q.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
         k_grad = k.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
         v_grad = v.grad[:, :, depermuter, :].reshape(B, H, W, W, D)
         results = [out, q_grad, k_grad, v_grad]
     else:
-        out= out.reshape(B, H, W, W, D)
+        out = out.reshape(B, H, W, W, D)
         q_grad = q.grad.reshape(B, H, W, W, D)
         k_grad = k.grad.reshape(B, H, W, W, D)
         v_grad = v.grad.reshape(B, H, W, W, D)
         results = [out, q_grad, k_grad, v_grad]
-
+
     del q, k, v, dO
-
+
     return results
 
 
@@ -69,25 +75,21 @@ def test_natten_masks(
     K_W=13,
     T_W=8,
     print_mask=True,
-):
-    query = torch.randn(
-        B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True
-    )
-    key = torch.randn(
-        B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True
-    )
-    value = torch.randn(
-        B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True
-    )
+):
+    query = torch.randn(B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True)
+    key = torch.randn(B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True)
+    value = torch.randn(B, H, W, W, D, device="cuda", dtype=torch.float16, requires_grad=True)
     gradOut = torch.randn(B, H, W, W, D, device="cuda", dtype=torch.float16)
-
-
+
     # Run naive NA
     naive_mask = generate_natten(W, W, K_W, K_W)
-    naive_results = run_natten(mask=naive_mask, query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)
-
+    naive_results = run_natten(
+        mask=naive_mask, query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask
+    )
+
     # Run tiled NA
     T_H = T_W
+
     def tiled_encoder(x, y):
         """
         Map 2-D coordinates to 1-D index for static tiles of T_H x T_W.
@@ -106,14 +108,33 @@ def tiled_decoder(idx):
         t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)
         t_offset = idx % (T_H * T_W)
         i_x, i_y = t_offset // T_W, t_offset % T_W
-        return t_x*T_W + i_x, t_y*T_H + i_y
+        return t_x * T_W + i_x, t_y * T_H + i_y
+
     tiled_mask = generate_tiled_natten(W, W, K_W, K_W, T_W, T_H)
-    tiled_results = run_natten(mask=tiled_mask, encoder=tiled_encoder, decoder=tiled_decoder, query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)
-
+    tiled_results = run_natten(
+        mask=tiled_mask,
+        encoder=tiled_encoder,
+        decoder=tiled_decoder,
+        query=query,
+        key=key,
+        value=value,
+        gradOut=gradOut,
+        print_mask=print_mask,
+    )
+
     # Run morton NA
     morton_mask = generate_morton_natten(W, W, K_W, K_W)
-    morton_results = run_natten(mask=morton_mask, encoder=morton_encode, decoder=morton_decode, query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)
-
+    morton_results = run_natten(
+        mask=morton_mask,
+        encoder=morton_encode,
+        decoder=morton_decode,
+        query=query,
+        key=key,
+        value=value,
+        gradOut=gradOut,
+        print_mask=print_mask,
+    )
+
     for naive, tiled, morton in zip(naive_results, tiled_results, morton_results):
         torch.testing.assert_close(naive, tiled, atol=1e-1, rtol=1e-2)
         print("Tiled NATTEN: Correctness check passed ✅")
@@ -124,5 +145,6 @@ def tiled_decoder(idx):
     del query, key, value, gradOut, naive_results, tiled_results
     torch.cuda.empty_cache()
 
+
 if __name__ == "__main__":
-    pytest.main([__file__])
+    pytest.main([__file__])
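
For readers skimming the diff, here is a small usage sketch (not part of the commit) of the flow the test exercises: generate_natten produces a mask_mod, create_block_mask turns it into a block mask over the flattened W * W sequence, and flex_attention consumes it. The CPU device, canvas size, kernel size, and head dimensions are illustrative assumptions; the test itself runs the same flow on CUDA through torch.compile.

# Illustrative sketch, not from this commit: naive NATTEN mask -> flex_attention.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from attn_gym.masks import generate_natten

B, H, D = 1, 1, 16  # assumed batch, heads, head dim
W, K_W = 16, 7  # assumed W x W canvas and K_W x K_W kernel

# mask_mod that restricts each query to its K_W x K_W neighborhood on the canvas.
natten_mask = generate_natten(W, W, K_W, K_W)
block_mask = create_block_mask(natten_mask, 1, 1, W * W, W * W, device="cpu")

# Inputs are flattened from the (W, W) canvas to a sequence of length W * W,
# just as run_natten does with .flatten(2, 3).
q = torch.randn(B, H, W * W, D)
k = torch.randn(B, H, W * W, D)
v = torch.randn(B, H, W * W, D)

out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 1, 256, 16])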
