
Commit a710e18

Reduce memory usage (#100)
1 parent 60f0f87 commit a710e18

File tree

1 file changed

examples/benchmark.py

Lines changed: 36 additions & 30 deletions
@@ -26,6 +26,21 @@
 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
 
 
+AVAILABLE_EXAMPLES = {
+    "causal": lambda: test_mask(mask_mod=causal_mask),
+    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+    "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
+    "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
+    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
+    "softcap": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
+    ),
+    "softcap_approx": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
+    ),
+}
+
+
 torch.set_default_device("cuda")
 torch.manual_seed(0)

@@ -97,19 +112,15 @@ def test_mask(
     causal_fav2_flops = 0.5 * B * H * D * S * S
     flops = density * B * H * D * S * S
 
-    # Forward pass
-    causal_fa2_time = do_bench(causal_fa2)
-    sdpa_mask_time = do_bench(sdpa_mask)
-    flex_ms = do_bench(flex_attention_call)
+    times = []
+    for attn in (causal_fa2, sdpa_mask, flex_attention_call):
+        fwd_time = do_bench(attn)
+        fwd_out = attn()
+        bwd_time = do_bench(lambda: fwd_out.backward(gradOut, retain_graph=True))  # noqa: F821
+        times.append((fwd_time, bwd_time))
 
-    # Backward pass
-    causal_fa2_out = causal_fa2()
-    sdpa_mask_out = sdpa_mask()
-    flex_out = flex_attention_call()
-
-    causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
-    sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))
-    flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
+        del fwd_out
+        torch.cuda.empty_cache()
 
     print_header(
         f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace(
@@ -140,6 +151,12 @@ def test_mask(
         torch.testing.assert_close(flex, sdpa_mask, atol=1e-1, rtol=1e-2)
 
         print("Correctness check passed ✅")
+
+    (
+        (causal_fa2_time, causal_fa2_bw_time),
+        (sdpa_mask_time, sdpa_mask_bw_time),
+        (flex_ms, flex_bw_ms),
+    ) = times
     # Usage in your results formatting:
     results = [
         [
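The nested unpacking above restores the original variable names from the times list collected in the benchmarking loop, so the results table that follows this hunk needs no changes. The pair order must match the tuple the loop iterated over; a toy illustration (values made up):

times = [(1.2, 3.4), (5.6, 7.8), (2.1, 4.3)]  # (fwd_ms, bwd_ms) per implementation
(
    (causal_fa2_time, causal_fa2_bw_time),
    (sdpa_mask_time, sdpa_mask_bw_time),
    (flex_ms, flex_bw_ms),
) = times
assert flex_bw_ms == 4.3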
@@ -210,28 +227,16 @@ def main(examples: List[str] = ["all"]):
     Args:
         examples: List of examples to run. If "all" is specified, all examples will be run.
     """
-    available_examples = {
-        "causal": lambda: test_mask(mask_mod=causal_mask),
-        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
-        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
-        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
-        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
-        "softcap": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
-        ),
-        "softcap_approx": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
-        ),
-    }
 
     if "all" in examples:
-        ex_to_run = list(available_examples.keys())
+        ex_to_run = list(AVAILABLE_EXAMPLES.keys())
     else:
         ex_to_run = examples
 
     for ex in ex_to_run:
-        if ex in available_examples:
-            available_examples[ex]()
+        if ex in AVAILABLE_EXAMPLES:
+            AVAILABLE_EXAMPLES[ex]()
+            torch.cuda.empty_cache()
         else:
             print(f"Warning: Unknown example key '{ex}'. Skipping.")
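Moving the example table to the module-level AVAILABLE_EXAMPLES constant (first hunk) lets the runner and the CLI help share one source of truth, and the added torch.cuda.empty_cache() keeps one example's cached allocator blocks from lingering into the next. If one wanted to verify the effect, a hedged sketch using PyTorch's allocator counters (run_with_peak_report is hypothetical, not part of the commit):

import torch

def run_with_peak_report(examples):
    # Report per-example peak memory; flat peaks indicate the cleanup works.
    for name, fn in examples.items():
        torch.cuda.reset_peak_memory_stats()
        fn()
        peak_gib = torch.cuda.max_memory_allocated() / 2**30
        print(f"{name}: peak {peak_gib:.2f} GiB")
        torch.cuda.empty_cache()  # mirror the commit's between-example cleanup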

@@ -248,8 +253,9 @@ def main(examples: List[str] = ["all"]):
         nargs="+",
         default=["all"],
         help="List of examples to run. Use space to separate multiple examples. "
-        "Available options: causal, alibi, sliding_window, prefix_lm, "
-        "document, softcap, softcap_approx, or 'all' to run all examples.",
+        "Available options: "
+        + ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
+        + ", or 'all' to run all examples.",
     )
 
     args = parser.parse_args()
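Building the help text from AVAILABLE_EXAMPLES means newly added examples show up in the CLI automatically. A standalone sketch of the same idea that also feeds the keys into argparse's choices so typos fail fast (the --examples flag name and the stand-in table are assumptions; the actual flag name is not shown in this diff):

import argparse

AVAILABLE_EXAMPLES = {"causal": None, "alibi": None}  # stand-in table

parser = argparse.ArgumentParser()
parser.add_argument(
    "--examples",
    nargs="+",
    default=["all"],
    choices=sorted(AVAILABLE_EXAMPLES.keys()) + ["all"],
    help="Available options: "
    + ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
    + ", or 'all' to run all examples.",
)
args = parser.parse_args()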
