|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | from enum import Enum
|
16 |
| -from typing import Any, List |
| 16 | +from typing import List |
17 | 17 |
|
18 | 18 | from pydantic import BaseModel, Field, field_validator
|
19 | 19 |
|
20 | 20 |
|
21 |
| -__all__ = ["TransformLocation", "TransformArgs"] |
| 21 | +__all__ = ["TransformArgs", "TransformLocation"] |
22 | 22 |
|
23 | 23 |
|
24 | 24 | class TransformLocation(str, Enum):
|
| 25 | + """ |
| 26 | + Enum representing which parameters/activations a transform weight should be applied |
| 27 | + to on a given module. |
| 28 | +
|
| 29 | + | -------------------------------------------------------------------------------------------------------- | # noqa: E501 |
| 30 | + | Name | Runtime | Values | Locations Where Inverse Could Be Applied | # noqa: E501 |
| 31 | + | --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501 |
| 32 | + | `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501 |
| 33 | + | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 |
| 34 | + | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 |
| 35 | + | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 |
| 36 | + | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 |
| 37 | + | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 |
| 38 | + | -------------------------------------------------------------------------------------------------------- | # noqa: E501 |
| 39 | + """ |
| 40 | + |
25 | 41 | INPUT = "input"
|
26 | 42 | WEIGHT_INPUT = "weight_input"
|
27 | 43 | WEIGHT_OUTPUT = "weight_output"
|
|
0 commit comments