
Commit 8367985

[Transforms] Transform Args, Scheme, and Config (#321)
* add args, scheme, config
* fix typo
* use input/output side
* remove llama spinquant
* remove side, rename to config_groups
* use enum
* update docstring, remove unused code

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent eddd2a1 commit 8367985

7 files changed: +390 -0 lines changed
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .transform_args import *
from .transform_scheme import *
from .transform_config import *
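Given these star imports and each module's `__all__`, the three public classes should be importable directly from the package namespace; a minimal check, assuming the package is installed as `compressed_tensors`:

from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

# each name is re-exported from the corresponding submodule's __all__
print(TransformArgs, TransformScheme, TransformConfig)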
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, List

from pydantic import BaseModel, Field, field_validator


__all__ = ["TransformArgs"]


class TransformLocation(str, Enum):
    INPUT = "input"
    WEIGHT_INPUT = "weight_input"
    WEIGHT_OUTPUT = "weight_output"
    OUTPUT = "output"
    K_CACHE = "k_cache"
    Q_ATTN = "q_attn"


class TransformArgs(BaseModel):
    """
    Arguments which define *how* and where a transform should be applied to a model

    :param targets: list of modules to apply transforms to
    :param location: where to apply transform on module, one of (`input`,
        `weight_input`, `weight_output`, `output`, `k_cache`, `q_attn`)
    :param inverse: whether or not to apply the inverse of a transform
    :param ignore: any modules which should be ignored from the targets list
    """

    targets: List[str]
    location: TransformLocation
    inverse: bool = Field(default=False)
    ignore: List[str] = Field(default_factory=list)

    @field_validator("targets", "ignore", mode="before")
    @classmethod
    def wrap_singleton(cls, value):
        if isinstance(value, str):
            return [value]
        return value
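A minimal usage sketch for `TransformArgs` (the module names are illustrative only): a bare string passed for `targets` or `ignore` is wrapped into a list by the `wrap_singleton` validator, and `location` is validated against the `TransformLocation` enum, whose string-valued members compare equal to their raw values.

from compressed_tensors.transform import TransformArgs

# singleton strings are wrapped into lists by the `wrap_singleton` validator
args = TransformArgs(targets="Linear", location="weight_input", ignore="lm_head")
assert args.targets == ["Linear"]
assert args.ignore == ["lm_head"]
assert args.location == "weight_input"  # TransformLocation is a str-valued enum
assert args.inverse is False  # default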
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from compressed_tensors.transform import TransformArgs, TransformScheme
from pydantic import BaseModel


__all__ = ["TransformConfig"]


class TransformConfig(BaseModel):
    """
    Configuration of transforms to be applied to a model. This config is to be
    serialized within a model's `config.json` file

    :param config_groups: A dictionary of `TransformScheme`s that should be applied
        to a particular model. The keys can be any arbitrary string
    """

    config_groups: Dict[str, TransformScheme]


# quip / quip sharp
QUIP = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["Linear"],
                    location="input",  # non-mergable
                ),
                TransformArgs(
                    targets=["Linear"],
                    location="weight_input",
                    inverse=True,
                ),
            ],
            randomize_modules=True,
        ),
        "u": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(
                    targets=["Linear"],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=["Linear"], location="output", inverse=True  # non-mergable
                ),
            ],
            randomize_modules=True,
        ),
    }
)


PRESET_CONFIGS = {
    "QUIP": QUIP,
}
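Because `TransformConfig` is a pydantic model, the `QUIP` preset above can be dumped to the JSON form intended for a model's `config.json`; a sketch, assuming this file is the `transform_config` submodule imported by the package `__init__` (the `indent` choice is illustrative):

from compressed_tensors.transform.transform_config import PRESET_CONFIGS

# serialize the preset to the JSON that would be embedded in config.json
quip_json = PRESET_CONFIGS["QUIP"].model_dump_json(indent=2)
print(quip_json)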
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from compressed_tensors.transform import TransformArgs
from pydantic import BaseModel, Field


__all__ = ["TransformScheme"]


class TransformScheme(BaseModel):
    """
    Scheme used to parameterize a particular transform type and specify how and
    where it should be applied to the model

    :param type: string indicating the particular transform type that should be
        created and applied. This should be one of the registered transform types
        (see `Transforms.registered_names()`)
    :param apply: list of TransformArgs containing the information about the
        modules that should be targeted by the specified transform
    :param randomize_modules: True if unique transforms should be applied to each
        unique module targeted by `apply`, otherwise reuse transform weights where
        applicable
    :param requires_grad: True if weights include gradients for training
    """

    type: str
    apply: List[TransformArgs] = Field(default_factory=list)
    randomize_modules: bool = Field(default=False)
    requires_grad: bool = Field(default=False)
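A short sketch of how a scheme composes with `TransformArgs` and round-trips through its dict form, which is how it nests under `TransformConfig.config_groups` (field values illustrative):

from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["Linear"], location="weight_output")],
    randomize_modules=True,
)

# round-trip through the serialized dict form
restored = TransformScheme.model_validate(scheme.model_dump())
assert restored == scheme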
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import TransformArgs


def test_basic():
    targets = ["Embedding"]
    location = "input"
    args = TransformArgs(targets=targets, location=location)

    assert args.targets == targets
    assert args.location == location
    assert len(args.ignore) == 0


def test_args_full():
    targets = ["Linear"]
    location = "weight_input"
    inverse = True
    ignore = ["model.layers.2"]

    args = TransformArgs(
        targets=targets,
        location=location,
        inverse=inverse,
        ignore=ignore,
    )

    assert args.targets == targets
    assert args.location == location
    assert args.inverse == inverse
    assert args.ignore == ignore


def test_singleton_targets():
    target = "target"
    location = "input"
    ignore = "ignore"
    args = TransformArgs(targets=target, location=location, ignore=ignore)

    assert args.targets == [target]
    assert args.location == location
    assert args.ignore == [ignore]
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme


@pytest.fixture
def basic_transform_scheme():
    targets = ["Embedding"]
    location = "input"
    basic_args = TransformArgs(targets=targets, location=location)

    return TransformScheme(
        type="hadamard",
        apply=[basic_args],
    )


def test_basic(basic_transform_scheme):
    config = TransformConfig(
        config_groups={
            "transform_0": basic_transform_scheme,
        }
    )
    assert isinstance(config.config_groups.get("transform_0"), TransformScheme)


def test_to_dict(basic_transform_scheme):
    config = TransformConfig(
        config_groups={
            "transform_0": basic_transform_scheme,
        }
    )
    config_dict = config.model_dump()
    assert "config_groups" in config_dict.keys()


def test_multiple_groups():
    location = "weight_input"

    targets_1 = ["model.layers.0.attn.v_proj"]
    linear_args_1 = TransformArgs(targets=targets_1, location=location)

    targets_2 = ["model.layers.0.attn.q_proj"]
    linear_args_2 = TransformArgs(targets=targets_2, location=location)

    scheme_1 = TransformScheme(
        type="hadamard",
        apply=[linear_args_1],
    )

    scheme_2 = TransformScheme(
        type="hadamard",
        apply=[linear_args_2],
    )
    config = TransformConfig(
        config_groups={"transform_0": scheme_1, "transform_1": scheme_2}
    )
    assert len(config.config_groups) == 2
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import TransformArgs, TransformScheme


def test_basic_scheme():
    targets = ["Linear"]
    location = "input"
    basic_args = TransformArgs(targets=targets, location=location)

    scheme = TransformScheme(
        type="hadamard",
        apply=[basic_args],
    )
    assert not scheme.randomize_modules
    assert scheme.type == "hadamard"
    assert len(scheme.apply) == 1
    assert isinstance(scheme.apply[0], TransformArgs)


def test_multiple_groups_global():
    targets = ["Embedding"]
    location = "input"
    embedding_args = TransformArgs(targets=targets, location=location)

    targets = ["Linear"]
    location = "weight_input"
    linear_args = TransformArgs(targets=targets, location=location)

    # same transform applied to multiple groups
    scheme = TransformScheme(
        type="hadamard",
        apply=[embedding_args, linear_args],
        randomize_modules=True,
    )

    assert scheme.randomize_modules
    assert scheme.type == "hadamard"
    assert len(scheme.apply) == 2
    assert isinstance(scheme.apply[0], TransformArgs)
    assert isinstance(scheme.apply[1], TransformArgs)


def test_multiple_groups():
    apply = []
    location = "weight_output"

    for i in range(20):
        targets = [f"model.layers.{i}.attn.v_proj", f"model.layers.{i}.attn.o_proj"]
        args = TransformArgs(targets=targets, location=location)
        apply.append(args)

    # randomize_modules defaults to False; transform weights are reused where applicable
    # same dimension/hidden dim
    scheme = TransformScheme(
        type="hadamard",
        apply=apply,
    )

    assert not scheme.randomize_modules
    assert scheme.type == "hadamard"
    assert len(scheme.apply) == 20
