Skip to content

Commit 2f5b1c8

Browse files
kylesayrs and brian-dellabetta
authored and committed
use random-hadamard, add correctness tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ba617db commit 2f5b1c8

File tree

5 files changed

+15
-19
lines changed

5 files changed

+15
-19
lines changed

examples/transform/llama3_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def tokenize(sample):
5858
# * quantize the weights to 4 bit with GPTQ with a group size 128
5959
recipe = [
6060
TransformModifier(),
61-
GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
61+
GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
6262
]
6363

6464
# Apply algorithms.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# flake8: noqa
22

3-
from .transform import TransformModifier
3+
from .transform import TransformModifier

src/llmcompressor/modifiers/transform/template/quip.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
2-
1+
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
32

43
QUIP = TransformConfig(
54
config_groups={
65
"v": TransformScheme(
7-
type="hadamard",
6+
type="random-hadamard",
87
apply=[
98
TransformArgs(
109
targets=["Linear"],
@@ -21,7 +20,7 @@
2120
randomize=True,
2221
),
2322
"u": TransformScheme(
24-
type="hadamard",
23+
type="random-hadamard",
2524
apply=[
2625
TransformArgs(
2726
targets=["Linear"],
@@ -32,10 +31,10 @@
3231
targets=["Linear"],
3332
location="output", # non-mergable
3433
inverse=True,
35-
ignore="lm_head"
34+
ignore="lm_head",
3635
),
3736
],
3837
randomize=True,
3938
),
4039
}
41-
)
40+
)

src/llmcompressor/modifiers/transform/template/spinquant.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
2-
1+
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
32

43
LLAMA_SPINQUANT = TransformConfig(
54
transform_groups={
@@ -62,4 +61,4 @@
6261
],
6362
),
6463
}
65-
)
64+
)
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Dict, Optional
22

3+
from compressed_tensors.transform import TransformScheme, apply_transform_config
4+
35
from llmcompressor.core import State
46
from llmcompressor.modifiers import Modifier
57

6-
from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory
7-
88
from .template.quip import QUIP
99

10+
1011
class TransformModifier(Modifier):
1112
preset_config: Optional[str] = None
1213
config_groups: Optional[Dict[str, TransformScheme]] = None
@@ -18,11 +19,8 @@ def on_initialize(self, state: State, **kwargs):
1819
# import config template and customize to model
1920
pass
2021

21-
22-
#config = TransformConfig(config_groups=self.config_groups)
22+
# config = TransformConfig(config_groups=self.config_groups)
2323
config = QUIP
2424

25-
# TODO: use CT-provided apply_transform_config
26-
for name, scheme in config.config_groups.items():
27-
factory = TransformFactory.from_scheme(scheme, name=name)
28-
factory.apply_to_model(state.model)
25+
apply_transform_config(state.model, config)
26+
breakpoint()

0 commit comments

Comments (0)