|
267 | 267 | "outputs": [],
|
268 | 268 | "source": [
|
269 | 269 | "def checkerboard(score, batch, head, token_q, token_kv):\n",
|
270 |
| - " score = torch.where(torch.abs(token_kv - token_q) % 1 == 0, score * 0.5, score)\n", |
| 270 | + " score = torch.where(torch.abs(token_kv - token_q) % 2 == 1, score * 0.5, score)\n", |
271 | 271 | " score = torch.where(torch.abs(token_kv - token_q) % 2 == 0, score * 2.0, score)\n",
|
272 | 272 | " return score\n",
|
273 | 273 | "\n",
|
|
316 | 316 | "The implementation using a score_mod:\n",
|
317 | 317 | "```Python\n",
|
318 | 318 | "def causal_bias(score, b, h, q_idx, kv_idx):\n",
|
319 |
| - " return torch.where(q >= kv_idx, score, -float(\"inf\"))\n", |
| 319 | + " return torch.where(q_idx >= kv_idx, score, -float(\"inf\"))\n", |
320 | 320 | "```\n",
|
321 | 321 | "\n",
|
322 | 322 | "Whenever you are writing a score_mod function that passes through the original score for some elements and sets others to -inf, you should likely be using a mask mod.\n",
|
|
326 | 326 | "The implementation using a mask_mod:\n",
|
327 | 327 | "```Python\n",
|
328 | 328 | "def casual_mask(b,h,q_idx, kv_idx):\n",
|
329 |
| - " return q >= kv_idx\n", |
| 329 | + " return q_idx >= kv_idx\n", |
330 | 330 | "```\n",
|
331 | 331 | "As you can see they look very similar, both return scalar tensors. The key differences\n",
|
332 | 332 | "1. mask_mods return boolean tensors where `True` indicates this score should be calculated, and `False` indicates that we want to mask out this score\n",
|
|
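A quick sanity check of what the two fixes in this diff change. This is a minimal, torch-free sketch: it mirrors the notebook's logic on plain Python scalars instead of tensors, and the names `checkerboard_old`, `checkerboard_new`, and the scalar versions of `causal_bias` and `causal_mask` are hypothetical stand-ins for illustration, not the notebook's actual functions.

```python
NEG_INF = -float("inf")


def checkerboard_old(score, token_q, token_kv):
    # Pre-fix logic: `% 1 == 0` is true for every integer distance,
    # so every score is halved before the even-distance doubling runs.
    dist = abs(token_kv - token_q)
    if dist % 1 == 0:
        score = score * 0.5
    if dist % 2 == 0:
        score = score * 2.0
    return score


def checkerboard_new(score, token_q, token_kv):
    # Post-fix logic: only odd distances are halved, only even ones doubled.
    dist = abs(token_kv - token_q)
    if dist % 2 == 1:
        score = score * 0.5
    if dist % 2 == 0:
        score = score * 2.0
    return score


def causal_bias(score, q_idx, kv_idx):
    # Scalar analogue of the score_mod: pass the score through or set it to -inf.
    return score if q_idx >= kv_idx else NEG_INF


def causal_mask(q_idx, kv_idx):
    # Scalar analogue of the mask_mod: True means "compute this score".
    return q_idx >= kv_idx


# Even distance: the old version halves then doubles, leaving the score unchanged.
even_old = checkerboard_old(1.0, 0, 2)
even_new = checkerboard_new(1.0, 0, 2)

# The score_mod and mask_mod formulations agree on which positions survive.
agree = all(
    (causal_bias(1.0, q, kv) != NEG_INF) == causal_mask(q, kv)
    for q in range(4)
    for kv in range(4)
)
```

Under these assumptions, the old `checkerboard` leaves even-distance scores unscaled instead of doubling them, which is the behavior the `% 2 == 1` change corrects. The second check illustrates why a pass-through-or-`-inf` score_mod can be expressed as a mask_mod.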