|
560 | 560 | "cell_type": "markdown",
|
561 | 561 | "metadata": {},
|
562 | 562 | "source": [
|
563 |
| - "### NATTEN MASKING\n", |
| 563 | + "### Stand-Alone Self-Attention Masking\n", |
564 | 564 | "\n",
|
565 | 565 | "In this case, imagine that we have a 2D image of size (H x W) flattened into a\n",
|
566 | 566 | "sequence of tokens. We only want to attend to tokens within 8 `pixels`, but\n",
|
567 | 567 | "from a 2D perspective.\n",
|
568 | 568 | "\n",
|
569 | 569 | "We can implement this mask_mod by first translating the 1D position into 2D coordinates. Then, we can simply check if the distance of both coordinates is within the window.\n",
|
570 | 570 | "\n",
|
571 |
| - "For more details check the paper's github repository [NATTEN](https://github.com/SHI-Labs/NATTEN) " |
| 571 | + "For more details check the paper, [Stand-Alone Self-Attention in Vision Models](https://arxiv.org/abs/1906.05909)" |
572 | 572 | ]
|
573 | 573 | },
|
574 | 574 | {
|
|
591 | 591 | " return idx // W, idx % W\n",
|
592 | 592 | "\n",
|
593 | 593 | "\n",
|
594 |
| - "def natten_mask(b, h, q_idx, kv_idx):\n", |
| 594 | + "def sasa_mask(b, h, q_idx, kv_idx):\n", |
595 | 595 | " q_x, q_y = get_x_y(q_idx)\n",
|
596 | 596 | " kv_x, kv_y = get_x_y(kv_idx)\n",
|
597 | 597 | " horizontal_mask = (q_x - kv_x).abs() <= WINDOW\n",
|
598 | 598 | " vertical_mask = (q_y - kv_y).abs() <= WINDOW\n",
|
599 | 599 | " return horizontal_mask & vertical_mask\n",
|
600 | 600 | "\n",
|
601 | 601 | "\n",
|
| 602 | + "test_mask(mask_mod=sasa_mask)" |
| 603 | + ] |
| 604 | + }, |
| 605 | + { |
| 606 | + "cell_type": "markdown", |
| 607 | + "metadata": {}, |
| 608 | + "source": [ |
| 609 | + "### NATTEN Masking\n", |
| 610 | + "\n", |
| 611 | + "Consider a 2D image of size (H x W) flattened into a sequence of tokens.\n", |
| 612 | + "Queries attend to keys in a fixed kernel area (K_H x K_W), centered where possible\n", |
| 613 | + "on the query, whilst staying within the canvas and always including the query.\n", |
| 614 | + "\n", |
| 615 | + "This is similar to SASA, except with extra handling to keep the kernel inside the canvas,\n", |
| 616 | + "ensuring that all queries attend to a fixed number of keys. \n", |
| 617 | + "Keys compare their position to the kernel center, not the query. The kernel center attempts\n", |
| 618 | + "to follow the query position, but is clamped to stay a fixed distance (its half-length) away\n", |
| 619 | + "from the canvas edge.\n", |
| 620 | + "\n", |
| 621 | + "See the [NATTEN repository](https://github.com/SHI-Labs/NATTEN) for more information. \n", |
| 622 | + "_Note: a more complete implementation of NATTEN would include support for kernel dilation._ \n", |
| 623 | + "_The NATTEN unfused kernel also has features like the ability to cross-attend to register tokens._\n", |
| 624 | + "_This capability is possible to express in Flex Attention but not attempted here._" |
| 625 | + ] |
| 626 | + }, |
| 627 | + { |
| 628 | + "cell_type": "code", |
| 629 | + "execution_count": null, |
| 630 | + "metadata": {}, |
| 631 | + "outputs": [], |
| 632 | + "source": [ |
| 633 | + "H = 128\n", |
| 634 | + "W = 128\n", |
| 635 | + "K_H = 7\n", |
| 636 | + "K_W = 7\n", |
| 637 | + "\n", |
| 638 | + "\n", |
| 639 | + "def get_x_y(idx):\n", |
| 640 | + " return idx // W, idx % W\n", |
| 641 | + "\n", |
| 642 | + "\n", |
| 643 | + "def natten_mask(\n", |
| 644 | + " b,\n", |
| 645 | + " h,\n", |
| 646 | + " q_idx,\n", |
| 647 | + " kv_idx,\n", |
| 648 | + "):\n", |
| 649 | + " q_x, q_y = get_x_y(q_idx)\n", |
| 650 | + " kv_x, kv_y = get_x_y(kv_idx)\n", |
| 651 | + " # kernel nominally attempts to center itself on the query, but kernel center\n", |
| 652 | + " # is clamped to a fixed distance (kernel half-length) from the canvas edge\n", |
| 653 | + " kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)\n", |
| 654 | + " kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)\n", |
| 655 | + " hori_mask = (kernel_x - kv_x).abs() <= K_W // 2\n", |
| 656 | + " vert_mask = (kernel_y - kv_y).abs() <= K_H // 2\n", |
| 657 | + " return hori_mask & vert_mask\n", |
| 658 | + "\n", |
| 659 | + "\n", |
602 | 660 | "test_mask(mask_mod=natten_mask)"
|
603 | 661 | ]
|
604 | 662 | },
|
|
0 commit comments