Skip to content

Commit 0e0b44d

Browse files
authored
Fix minor typo in example (#57)
1 parent f7c93de commit 0e0b44d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/flex_attn.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@
267267
"outputs": [],
268268
"source": [
269269
"def checkerboard(score, batch, head, token_q, token_kv):\n",
270-
" score = torch.where(torch.abs(token_kv - token_q) % 1 == 0, score * 0.5, score)\n",
270+
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 1, score * 0.5, score)\n",
271271
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 0, score * 2.0, score)\n",
272272
" return score\n",
273273
"\n",
@@ -316,7 +316,7 @@
316316
"The implementation using a score_mod:\n",
317317
"```Python\n",
318318
"def causal_bias(score, b, h, q_idx, kv_idx):\n",
319-
" return torch.where(q >= kv_idx, score, -float(\"inf\"))\n",
319+
" return torch.where(q_idx >= kv_idx, score, -float(\"inf\"))\n",
320320
"```\n",
321321
"\n",
322322
"Whenever you are writing a score_mod function that passes through the original score for some elements and sets others to -inf, you should likely be using a mask mod.\n",
@@ -326,7 +326,7 @@
326326
"```Python\n",
327327
"The implementation using a mask_mod:\n",
328328
"def casual_mask(b,h,q_idx, kv_idx):\n",
329-
" return q >= kv_idx\n",
329+
" return q_idx >= kv_idx\n",
330330
"```\n",
331331
"As you can see they look very similar, both return scalar tensors. The key differences\n",
332332
"1. mask_mods return boolean tensors where `True` indicates this score should be calculated, and `False` indicates we that we want to mask out this score\n",

0 commit comments

Comments
 (0)