26
26
from attn_gym .mods import generate_alibi_bias , generate_tanh_softcap
27
27
28
28
29
# Registry mapping an example name to a zero-argument callable that runs the
# corresponding attention benchmark.  Kept at module level so that both the
# example runner and the argparse help text can enumerate the same set.
# NOTE(review): test_mask, causal_mask, generate_sliding_window,
# generate_prefix_lm_mask and run_document_masking are defined elsewhere in
# this file; generate_alibi_bias / generate_tanh_softcap come from
# attn_gym.mods (imported above).
AVAILABLE_EXAMPLES = {
    "causal": lambda: test_mask(mask_mod=causal_mask),
    # ALiBi produces a score bias rather than a boolean mask; correctness
    # comparison against the masked SDPA baseline is skipped for score mods.
    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
    "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
    "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
    # Exact vs. approximate tanh soft-capping variants (cap value 30).
    "softcap": lambda: test_mask(
        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
    ),
    "softcap_approx": lambda: test_mask(
        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
    ),
}
29
44
torch .set_default_device ("cuda" )
30
45
torch .manual_seed (0 )
31
46
@@ -97,19 +112,15 @@ def test_mask(
97
112
causal_fav2_flops = 0.5 * B * H * D * S * S
98
113
flops = density * B * H * D * S * S
99
114
100
- # Forward pass
101
- causal_fa2_time = do_bench (causal_fa2 )
102
- sdpa_mask_time = do_bench (sdpa_mask )
103
- flex_ms = do_bench (flex_attention_call )
115
+ times = []
116
+ for attn in (causal_fa2 , sdpa_mask , flex_attention_call ):
117
+ fwd_time = do_bench (attn )
118
+ fwd_out = attn ()
119
+ bwd_time = do_bench (lambda : fwd_out .backward (gradOut , retain_graph = True )) # noqa: F821
120
+ times .append ((fwd_time , bwd_time ))
104
121
105
- # Backward pass
106
- causal_fa2_out = causal_fa2 ()
107
- sdpa_mask_out = sdpa_mask ()
108
- flex_out = flex_attention_call ()
109
-
110
- causal_fa2_bw_time = do_bench (lambda : causal_fa2_out .backward (gradOut , retain_graph = True ))
111
- sdpa_mask_bw_time = do_bench (lambda : sdpa_mask_out .backward (gradOut , retain_graph = True ))
112
- flex_bw_ms = do_bench (lambda : flex_out .backward (gradOut , retain_graph = True ))
122
+ del fwd_out
123
+ torch .cuda .empty_cache ()
113
124
114
125
print_header (
115
126
f"{ score_mod .__name__ if score_mod is not None else mask_mod .__name__ } " .replace (
@@ -140,6 +151,12 @@ def test_mask(
140
151
torch .testing .assert_close (flex , sdpa_mask , atol = 1e-1 , rtol = 1e-2 )
141
152
142
153
print ("Correctness check passed ✅" )
154
+
155
+ (
156
+ (causal_fa2_time , causal_fa2_bw_time ),
157
+ (sdpa_mask_time , sdpa_mask_bw_time ),
158
+ (flex_ms , flex_bw_ms ),
159
+ ) = times
143
160
# Usage in your results formatting:
144
161
results = [
145
162
[
@@ -210,28 +227,16 @@ def main(examples: List[str] = ["all"]):
210
227
Args:
211
228
examples: List of examples to run. If "all" is specified, all examples will be run.
212
229
"""
213
- available_examples = {
214
- "causal" : lambda : test_mask (mask_mod = causal_mask ),
215
- "alibi" : lambda : test_mask (score_mod = generate_alibi_bias (16 ), skip_correctness = True ),
216
- "sliding_window" : lambda : test_mask (mask_mod = generate_sliding_window (window_size = 1024 )),
217
- "prefix_lm" : lambda : test_mask (mask_mod = generate_prefix_lm_mask (prefix_length = 1024 )),
218
- "document" : lambda : run_document_masking (max_seq_len = 32768 , num_docs = 12 ),
219
- "softcap" : lambda : test_mask (
220
- score_mod = generate_tanh_softcap (30 , approx = False ), skip_correctness = True
221
- ),
222
- "softcap_approx" : lambda : test_mask (
223
- score_mod = generate_tanh_softcap (30 , approx = True ), skip_correctness = True
224
- ),
225
- }
226
230
227
231
if "all" in examples :
228
- ex_to_run = list (available_examples .keys ())
232
+ ex_to_run = list (AVAILABLE_EXAMPLES .keys ())
229
233
else :
230
234
ex_to_run = examples
231
235
232
236
for ex in ex_to_run :
233
- if ex in available_examples :
234
- available_examples [ex ]()
237
+ if ex in AVAILABLE_EXAMPLES :
238
+ AVAILABLE_EXAMPLES [ex ]()
239
+ torch .cuda .empty_cache ()
235
240
else :
236
241
print (f"Warning: Unknown example key '{ ex } '. Skipping." )
237
242
@@ -248,8 +253,9 @@ def main(examples: List[str] = ["all"]):
248
253
nargs = "+" ,
249
254
default = ["all" ],
250
255
help = "List of examples to run. Use space to separate multiple examples. "
251
- "Available options: causal, alibi, sliding_window, prefix_lm, "
252
- "document, softcap, softcap_approx, or 'all' to run all examples." ,
256
+ "Available options: "
257
+ + ", " .join (sorted (AVAILABLE_EXAMPLES .keys ()))
258
+ + ", or 'all' to run all examples." ,
253
259
)
254
260
255
261
args = parser .parse_args ()
0 commit comments