
Commit 31ac8e9

docstrings
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent d0e5bc5 commit 31ac8e9

3 files changed: +43 -0 lines changed


src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 14 additions & 0 deletions
@@ -41,6 +41,20 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
     existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
     rotations, meaning that they require additional computation at runtime.

+    Lifecycle:
+        - on_initialize
+            - infer SpinQuantMappings & NormMappings
+            - as needed, create transform schemes for R1, R2, R3, & R4
+        - on_start
+            - normalize embeddings
+            - fuse norm layers into subsequent Linear layers
+            - apply TransformConfig
+                - fuse transforms into weights for mergeable transforms
+                - add hooks for online transforms
+        - on sequential epoch end
+        - on_end
+        - on_finalize
+
     :param rotations: A list containing the names of rotations to apply to the model.
         Possible rotations include R1, R2, R3, and R4
     :param transform_type: The type of transform to apply to the model.
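For orientation, a minimal usage sketch of the modifier documented above. The class name and the rotations/transform_type parameters come from this diff; the import path mirrors the file location in this commit, and the concrete argument values are assumptions that may differ from the package's public API:

```python
# Minimal sketch, not verified against the released API.
from llmcompressor.modifiers.transform.spinquant.base import SpinQuantModifier

# R1/R2 are fused into existing weights ("offline"); adding R3/R4 would also
# register runtime hooks for the online rotations.
modifier = SpinQuantModifier(
    rotations=["R1", "R2"],      # any subset of {R1, R2, R3, R4}
    transform_type="hadamard",   # assumed value; see the transform_type docs
)
```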

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 19 additions & 0 deletions
@@ -8,6 +8,25 @@
 
 
 class SpinQuantMapping(BaseModel):
+    """
+    SpinQuant needs to know the entire architecture of the model,
+    as R1, R2, R3, and R4 rotations need to be applied to specific
+    layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
+
+    :param embedding: name or regex of embedding layer
+    :param attn_q: name or regex of q_proj layer in attention block
+    :param attn_k: name or regex of k_proj layer in attention block
+    :param attn_v: name or regex of v_proj layer in attention block
+    :param attn_o: name or regex of o_proj layer in attention block
+    :param attn_head_dim: head_dim of the attention module, needed
+        because R2 needs to be applied head-wise to v_proj and
+        o_proj
+    :param mlp_in: list of names or regexes for the mlp layers that
+        receive the input to the MLP block, usually up_proj and gate_proj
+    :param mlp_out: list of names or regexes for the mlp layers that
+        constitute the output of the MLP block, usually down_proj
+    """
+
     embedding: str

     attn_q: str
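To make the mapping concrete, here is an illustrative instance for a Llama-style decoder. Only the field names follow the class and docstring above; the import path mirrors the file location in this commit, and the `re:` regex patterns and the head_dim value are assumptions for a typical checkpoint, not taken from the repository's built-in mappings:

```python
# Illustrative only: regex targets and head_dim are assumed.
from llmcompressor.modifiers.transform.spinquant.mappings import SpinQuantMapping

llama_style_mapping = SpinQuantMapping(
    embedding="re:.*embed_tokens",
    attn_q="re:.*q_proj",
    attn_k="re:.*k_proj",
    attn_v="re:.*v_proj",
    attn_o="re:.*o_proj",
    attn_head_dim=128,  # R2 is applied per attention head, so head_dim is needed
    mlp_in=["re:.*gate_proj", "re:.*up_proj"],
    mlp_out=["re:.*down_proj"],
)
```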

src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,16 @@
 
 
 class NormMapping(BaseModel):
+    """
+    SpinQuant needs to know where every norm layer exists in the model,
+    as well as all the subsequent Linear layers the norm passes into.
+    This is because the norm layer weights need to be normalized before
+    transforms can be fused into Linear layers.
+
+    :param norm: name or regex that matches norm layer in model
+    :param linears: list of names or regexes of Linear layers that
+        receive input from norm.
+    """
     norm: str
     linears: List[str]
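As an illustration, a NormMapping pairing a decoder layer's input layernorm with the attention projections it feeds might look like the following; only the field names come from the class above, while the import path mirrors the file location in this commit and the regex values are assumptions based on common Llama module names:

```python
# Hypothetical example; regex targets are assumed.
from llmcompressor.modifiers.transform.spinquant.norm_mappings import NormMapping

attn_norm_mapping = NormMapping(
    norm="re:.*input_layernorm",
    linears=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
)
```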
