Skip to content

Commit 290c1e7

Browse files
authored
fix alibi bias (#86)
1 parent bbf437e commit 290c1e7

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

attn_gym/mods/alibi.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def generate_alibi_bias(H: int) -> _score_mod_signature:
1616

1717
def alibi_mod(score, b, h, q_idx, kv_idx):
1818
scale = torch.exp2(-((h + 1) * 8.0 / H))
19-
bias = (q_idx - kv_idx) * scale
19+
bias = (kv_idx - q_idx) * scale
2020
return score + bias
2121

2222
return alibi_mod
2323

2424

25-
def main(device: str = "cpu"):
25+
def main(device: str = "cpu", causal: bool = True):
2626
"""Visualize the attention scores alibi bias score mod.
2727
2828
Args:
@@ -40,8 +40,16 @@ def make_tensor():
4040

4141
alibi_score_mod = generate_alibi_bias(H)
4242

43+
def causal_mask(b, h, q_idx, kv_idx):
44+
return q_idx >= kv_idx
45+
4346
visualize_attention_scores(
44-
query, key, score_mod=alibi_score_mod, device=device, name="alibi_score_mod"
47+
query,
48+
key,
49+
score_mod=alibi_score_mod,
50+
mask_mod=causal_mask if causal else None,
51+
device=device,
52+
name=f"alibi_score_mod_{'causal' if causal else 'non-causal'}",
4553
)
4654

4755

attn_gym/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,27 @@ def visualize_attention_scores(
115115
batch_idx=batch_idx,
116116
head_idx=head_idx,
117117
)
118+
# If both score_mod and mask_mod are provided, apply both
119+
if score_mod is not None and mask_mod is not None:
120+
mask_viz = create_score_mod(
121+
query,
122+
key,
123+
score_mod=None,
124+
mask_mod=mask_mod,
125+
scale=scale,
126+
device=device,
127+
batch_idx=batch_idx,
128+
head_idx=head_idx,
129+
)
130+
# Apply mask by setting masked positions to -inf
131+
scores_viz = torch.where(mask_viz == 0, float("-inf"), scores_viz)
118132

119133
suffix_title = f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""
120134

121135
fig, ax = plt.subplots(figsize=(12, 10))
122136
color = "viridis" if score_mod is not None else "cividis"
137+
if score_mod is not None and mask_mod is not None:
138+
color = "plasma"
123139
im = ax.imshow(scores_viz.cpu().detach()[0, 0, :, :], aspect="auto", cmap=color)
124140
fig.colorbar(im)
125141

examples/flex_attn.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,13 +700,13 @@
700700
"\n",
701701
"\n",
702702
"def alibi_and_causal_closure(score, b, h, q_idx, kv_idx):\n",
703-
" bias = alibi_bias[h] * (q_idx - kv_idx)\n",
703+
" bias = alibi_bias[h] * (kv_idx - q_idx)\n",
704704
" return score + bias\n",
705705
"\n",
706706
"\n",
707707
"def alibi_and_causal_functional(score, b, h, q_idx, kv_idx):\n",
708708
" scale = torch.exp2(-((h + 1) * 8.0 / H))\n",
709-
" bias = (q_idx - kv_idx) * scale\n",
709+
" bias = (kv_idx - q_idx) * scale\n",
710710
" return score + bias\n",
711711
"\n",
712712
"\n",

0 commit comments

Comments (0)