From be9c92fa16a1e3565e94d9181acdc0bcbad5640d Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 13 Jan 2025 12:26:55 +0800
Subject: [PATCH 1/6] Check correctness for `score_mod` implementations

---
 examples/benchmark.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 50debe2..518dd3c 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -28,16 +28,13 @@
 AVAILABLE_EXAMPLES = {
     "causal": lambda: test_mask(mask_mod=causal_mask),
+    "causal_score": lambda: test_mask(score_mod=lambda score, b, h, q_idx, kv_idx: torch.where(causal_mask(b, h, q_idx, kv_idx), score, torch.finfo(score.dtype))),
     "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
     "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
     "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
     "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
-    "softcap": lambda: test_mask(
-        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
-    ),
-    "softcap_approx": lambda: test_mask(
-        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
-    ),
+    "softcap": lambda: test_mask(score_mod=generate_tanh_softcap(30, approx=False)),
+    "softcap_approx": lambda: test_mask(score_mod=generate_tanh_softcap(30, approx=True)),
 }
@@ -93,6 +90,14 @@ def test_mask(
         block_mask = None
     sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod
     mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device)
+    if score_mod:
+        mask = torch.where(mask, score_mod(
+            torch.zeros_like(mask, dtype=data_type),
+            torch.tensor([1], dtype=data_type),
+            torch.tensor([1], dtype=data_type),
+            torch.tensor([[s for i in range(S)] for s in range(S)], dtype=torch.int64),
+            torch.tensor([[i for i in range(S)] for s in range(S)], dtype=torch.int64),
+        ), torch.finfo(data_type).min)

     qkv = [
         torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True)
@@ -121,6 +126,11 @@ def test_mask(
         del fwd_out
         torch.cuda.empty_cache()

+    (
+        (causal_fa2_time, causal_fa2_bw_time),
+        (sdpa_mask_time, sdpa_mask_bw_time),
+        (flex_ms, flex_bw_ms),
+    ) = times
     print_header(
         f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace(
@@ -152,11 +162,6 @@ def test_mask(
         print("Correctness check passed ✅")

-    (
-        (causal_fa2_time, causal_fa2_bw_time),
-        (sdpa_mask_time, sdpa_mask_bw_time),
-        (flex_ms, flex_bw_ms),
-    ) = times
     # Usage in your results formatting:
     results = [
         [
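
Note: the two nested list comprehensions in the second hunk of the patch above build the S x S index grids that a FlexAttention-style score_mod expects; in the first grid, row r is filled with the query index r, and in the second, column c is filled with the key/value index c. A broadcast-based sketch of the same construction, for reference only (the names and the zero batch/head placeholder are mine, not part of the patch, and it assumes the score_mod uses broadcast-friendly ops, as the attn-gym mods do):

import torch

S = 4096  # placeholder sequence length
q_idx = torch.arange(S).view(S, 1).expand(S, S)   # row r is all r: the query position
kv_idx = torch.arange(S).view(1, S).expand(S, S)  # column c is all c: the key/value position
# A score_mod evaluated on an all-zero score grid with these indices, e.g.
#     score_mod(torch.zeros(S, S), torch.tensor(0), torch.tensor(0), q_idx, kv_idx)
# yields the dense additive bias that the hunk then feeds into torch.where(mask, ..., min).
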
From e0c9613c46072bf289c19ccf711ba67eef35d2d2 Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 13 Jan 2025 14:28:36 +0800
Subject: [PATCH 2/6] Validate alibi (with local patch to return score)

---
 examples/benchmark.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 518dd3c..5aac7b4 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -29,7 +29,7 @@ AVAILABLE_EXAMPLES = {
     "causal": lambda: test_mask(mask_mod=causal_mask),
     "causal_score": lambda: test_mask(score_mod=lambda score, b, h, q_idx, kv_idx: torch.where(causal_mask(b, h, q_idx, kv_idx), score, torch.finfo(score.dtype))),
-    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False),
     "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
     "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
     "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
@@ -88,16 +88,14 @@ def test_mask(
         block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
     else:
         block_mask = None
-    sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod
-    mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device)
-    if score_mod:
-        mask = torch.where(mask, score_mod(
-            torch.zeros_like(mask, dtype=data_type),
-            torch.tensor([1], dtype=data_type),
-            torch.tensor([1], dtype=data_type),
-            torch.tensor([[s for i in range(S)] for s in range(S)], dtype=torch.int64),
-            torch.tensor([[i for i in range(S)] for s in range(S)], dtype=torch.int64),
-        ), torch.finfo(data_type).min)
+    mask = create_mask(mask_mod, 1, H, S, S, device=device) if mask_mod else None
+    bias = create_mask(score_mod, 1, H, S, S, device=device) if score_mod else None
+    if bias is not None:
+        bias = bias.to(dtype=data_type)
+        if mask:
+            mask = bias.where(mask, torch.finfo(data_type).min)
+    else:
+        assert mask is not None

     qkv = [
         torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True)

From a98fdea2127ed417ae02eac94b13679ba0f44460 Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 13 Jan 2025 14:30:28 +0800
Subject: [PATCH 3/6] fmt and fix

---
 examples/benchmark.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 5aac7b4..0054809 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -28,7 +28,11 @@
 AVAILABLE_EXAMPLES = {
     "causal": lambda: test_mask(mask_mod=causal_mask),
-    "causal_score": lambda: test_mask(score_mod=lambda score, b, h, q_idx, kv_idx: torch.where(causal_mask(b, h, q_idx, kv_idx), score, torch.finfo(score.dtype))),
+    "causal_score": lambda: test_mask(
+        score_mod=lambda score, *args: torch.where(
+            causal_mask(*args), score, torch.finfo(score.dtype).min
+        )
+    ),
     "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False),
     "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
     "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),

From 12b5e161cfdeacfd3ada9cc635737765fcb81c59 Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 13 Jan 2025 14:32:19 +0800
Subject: [PATCH 4/6] fix

---
 examples/benchmark.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 0054809..2d58d6f 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -97,7 +97,8 @@ def test_mask(
     if bias is not None:
         bias = bias.to(dtype=data_type)
         if mask:
-            mask = bias.where(mask, torch.finfo(data_type).min)
+            bias = bias.where(mask, torch.finfo(data_type).min)
+        mask = bias
     else:
         assert mask is not None

From 1774355b822485c73dbb3c2bfa4ac0924bf47e1b Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 13 Jan 2025 14:47:00 +0800
Subject: [PATCH 5/6] format

---
 attn_gym/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/attn_gym/utils.py b/attn_gym/utils.py
index 69cae71..8b6dc97 100644
--- a/attn_gym/utils.py
+++ b/attn_gym/utils.py
@@ -100,9 +100,9 @@ def visualize_attention_scores(
     Returns:
         None
     """
-    assert (
-        score_mod is not None or mask_mod is not None
-    ), "Must provide either score_mod or mask_mod"
+    assert score_mod is not None or mask_mod is not None, (
+        "Must provide either score_mod or mask_mod"
+    )
     query = query[batch_idx, head_idx, :, :]
     key = key[batch_idx, head_idx, :, :]
     scores_viz = create_score_mod(
score_mod or mask_mod" + ) query = query[batch_idx, head_idx, :, :] key = key[batch_idx, head_idx, :, :] scores_viz = create_score_mod( From 29f43391be8c952e86307c6cb1682022abf7de00 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Mon, 13 Jan 2025 15:14:38 +0800 Subject: [PATCH 6/6] `score_mod` do not support `*args` --- examples/benchmark.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 2d58d6f..690d524 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -26,13 +26,13 @@ from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap +def _causal_score(score, b, h, q_idx, kv_idx): + return causal_mask(b, h, q_idx, kv_idx).where(score, torch.finfo(score.dtype).min) + + AVAILABLE_EXAMPLES = { "causal": lambda: test_mask(mask_mod=causal_mask), - "causal_score": lambda: test_mask( - score_mod=lambda score, *args: torch.where( - causal_mask(*args), score, torch.finfo(score.dtype).min - ) - ), + "causal_score": lambda: test_mask(score_mod=_causal_score), "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False), "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)), "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),