Commit fadaaf8

transform arg support

1 parent 749420b commit fadaaf8
7 files changed: +404, -0 lines changed
@@ -0,0 +1,68 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator


__all__ = ["TransformationArgs", "ModuleTarget"]


# TODO: we eventually want to target generic parameters but for now, this
# is sufficient
class ModuleTarget(str, Enum):
    """
    Enum storing the parameter or activation types targeted by transforms
    in a particular module.
    """

    WEIGHT = "weight"
    INPUT_ACTIVATIONS = "input_activations"
    OUTPUT_ACTIVATIONS = "output_activations"

    @classmethod
    def has_value(cls, value):
        return value in cls._value2member_map_


class TransformationArgs(BaseModel):
    """
    User-facing arguments used to define which modules, and which of their
    parameters and/or activations, should be targeted by a particular transform.

    :param targets: list of layers to target
    :param module_targets: list of layer parameters and/or activations onto which
        the transform should be applied. The same transform will be applied to
        all module targets in the list.
    :param call_args: dictionary of args needed by the transform at runtime,
        beyond the input tensor or the transform itself
    :param ignore: any submodules which should be excluded from the targets list
    """

    targets: List[str]
    module_targets: List[Union[ModuleTarget, str]]
    # default_factory avoids sharing one mutable default across instances; a
    # bare defaultdict() here would also silently behave like a plain dict
    call_args: Optional[Dict[str, Any]] = Field(default_factory=dict)
    ignore: Optional[List[str]] = Field(default_factory=list)

    @field_validator("module_targets", mode="before")
    def validate_module_target(cls, value) -> List[ModuleTarget]:
        # Normalize strings (in any casing) to ModuleTarget members and
        # reject anything that is not a known target
        module_targets_list = []
        for v in value:
            assert ModuleTarget.has_value(v.lower())
            module_targets_list.append(ModuleTarget(v.lower()))

        return module_targets_list
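
A minimal usage sketch of the class above (it is importable as compressed_tensors.transforms.transform_args, per the imports in the other files of this commit); note how the before-mode validator normalizes plain strings, in any casing, to ModuleTarget members:

from compressed_tensors.transforms.transform_args import (
    ModuleTarget,
    TransformationArgs,
)

# Strings are accepted and normalized to ModuleTarget members
args = TransformationArgs(targets=["Linear"], module_targets=["WEIGHT"])
assert args.module_targets[0] == ModuleTarget.WEIGHT

# Unspecified fields fall back to empty containers
assert args.call_args == {}
assert args.ignore == []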
@@ -0,0 +1,37 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from compressed_tensors.transforms.transform_scheme import TransformationScheme
from pydantic import BaseModel


__all__ = ["TransformationConfig"]


class TransformationConfig(BaseModel):
    """
    Configuration of the transforms to be added to a model's config.json.

    :param transform_groups: a dictionary of the different TransformationSchemes
        that should be applied to a particular model. The keys can be any
        arbitrary string, and a TransformationScheme should be provided for
        each new transform type.
    """

    transform_groups: Dict[str, TransformationScheme]

    def to_dict(self):
        return self.model_dump()
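
A sketch of how this serializes: to_dict is just pydantic's model_dump, so the result is a nested dict mirroring the field names of the schemes and args it contains. The "hadamard" transform type and the size creation arg below are taken from the tests in this commit:

from compressed_tensors.transforms.transform_args import ModuleTarget, TransformationArgs
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_scheme import TransformationScheme

args = TransformationArgs(targets=["Embedding"], module_targets=[ModuleTarget.INPUT_ACTIVATIONS])
scheme = TransformationScheme(
    transform_type="hadamard",
    groups=[args],
    transform_creation_args={"size": 1024},
)
config = TransformationConfig(transform_groups={"transform_0": scheme})

# A nested dict keyed by "transform_groups", suitable for embedding
# in a model's config.json
config_dict = config.to_dict()
assert "transform_groups" in config_dict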
@@ -0,0 +1,35 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict


__all__ = ["TransformData"]


# TODO: adding for now but we may not need it during runtime depending on the
# integration.
@dataclass
class TransformData:
    """
    Data that is required at runtime in order to apply the transform.

    Example:
        data = {transform_name: {"apply": Callable, "call_args": Dict}}
        transform_data = TransformData(data=data)
    """

    data: Dict
    idx: int = 0
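
The docstring's example written out end to end. The module path transform_data is inferred from this commit's naming pattern, and the "hadamard" key and apply function are placeholders, not part of this commit:

# assumes the module is importable as ...transforms.transform_data
from compressed_tensors.transforms.transform_data import TransformData

def apply_placeholder(input_tensor, transpose=False):
    # stand-in for a registered transform's apply function
    return input_tensor

data = {"hadamard": {"apply": apply_placeholder, "call_args": {"transpose": True}}}
transform_data = TransformData(data=data)
assert transform_data.idx == 0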
@@ -0,0 +1,48 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.transform_args import TransformationArgs
from pydantic import BaseModel, field_validator


__all__ = ["TransformationScheme"]


class TransformationScheme(BaseModel):
    """
    :param transform_type: string indicating the particular transform that
        should be created and applied. This should be one of the registered
        transforms, i.e. in Transforms.registered_names()
    :param groups: the TransformationArgs containing the information about the
        layers that should be targeted by the specified transform. By providing
        a list, users have the ability to apply the same transform type (with
        the same set of arguments) to different layers.
    :param transform_creation_args: arguments needed to initialize the
        transform, if any
    :param global_transform: whether an identical transform is applied to all
        the TransformationArgs in the groups list
    """

    transform_type: str
    groups: List[TransformationArgs]
    global_transform: bool = False
    transform_creation_args: Optional[Dict[str, Any]] = None

    @field_validator("transform_type", mode="before")
    def validate_transform_type(cls, value) -> str:
        # Only transforms registered with the Transforms registry are valid
        assert value in Transforms.registered_names()
        return value
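
A sketch exercising the global_transform flag documented above, which the tests below do not cover: one scheme whose single, identical transform is shared across two target groups. The layer names and creation args mirror those used in the tests:

from compressed_tensors.transforms.transform_args import TransformationArgs
from compressed_tensors.transforms.transform_scheme import TransformationScheme

shared_scheme = TransformationScheme(
    transform_type="hadamard",
    groups=[
        TransformationArgs(targets=["model.layers.0.attn.q_proj"], module_targets=["weight"]),
        TransformationArgs(targets=["model.layers.0.attn.v_proj"], module_targets=["weight"]),
    ],
    # per the docstring: apply one identical transform to every group
    global_transform=True,
    transform_creation_args={"size": 1024},
)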
@@ -0,0 +1,52 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from compressed_tensors.transforms.transform_args import (
    ModuleTarget,
    TransformationArgs,
)


def test_transform_args_basic():
    targets = ["Embedding"]
    module_targets = [ModuleTarget.INPUT_ACTIVATIONS]
    basic_args = TransformationArgs(targets=targets, module_targets=module_targets)

    assert basic_args.targets[0] == "Embedding"
    assert basic_args.module_targets[0] == ModuleTarget.INPUT_ACTIVATIONS
    # call_args defaults to an empty dict and ignore to an empty list
    assert isinstance(basic_args.call_args, dict)
    assert len(basic_args.ignore) == 0


def test_transform_args_full():
    targets = ["Linear"]
    module_targets = ["weight", "input_activations"]
    ignore = ["model.layers.2"]
    call_args = {"transpose": True}

    full_args = TransformationArgs(
        targets=targets,
        module_targets=module_targets,
        call_args=call_args,
        ignore=ignore,
    )

    assert full_args.targets == targets
    assert full_args.ignore == ignore
    # the validator normalizes the plain strings to ModuleTarget members
    assert full_args.module_targets[0] == ModuleTarget.WEIGHT
    assert full_args.module_targets[1] == ModuleTarget.INPUT_ACTIVATIONS
    assert full_args.call_args.get("transpose")
@@ -0,0 +1,80 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
from compressed_tensors.transforms.transform_args import (
    ModuleTarget,
    TransformationArgs,
)
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_scheme import TransformationScheme


@pytest.fixture
def basic_transform_scheme():
    targets = ["Embedding"]
    module_targets = [ModuleTarget.INPUT_ACTIVATIONS]
    basic_args = TransformationArgs(targets=targets, module_targets=module_targets)

    scheme = TransformationScheme(
        transform_type="hadamard",
        groups=[basic_args],
        transform_creation_args={"size": 1024},
    )
    return scheme


def test_basic(basic_transform_scheme):
    config = TransformationConfig(
        transform_groups={
            "transform_0": basic_transform_scheme,
        }
    )
    assert isinstance(config.transform_groups.get("transform_0"), TransformationScheme)


def test_to_dict(basic_transform_scheme):
    config = TransformationConfig(
        transform_groups={
            "transform_0": basic_transform_scheme,
        }
    )
    config_dict = config.to_dict()
    assert "transform_groups" in config_dict.keys()


def test_multiple_groups():
    module_targets = [ModuleTarget.WEIGHT]

    targets_1 = ["model.layers.0.attn.v_proj"]
    linear_args_1 = TransformationArgs(targets=targets_1, module_targets=module_targets)

    targets_2 = ["model.layers.0.attn.q_proj"]
    linear_args_2 = TransformationArgs(targets=targets_2, module_targets=module_targets)

    scheme_1 = TransformationScheme(
        transform_type="hadamard",
        groups=[linear_args_1],
        transform_creation_args={"size": 1024},
    )

    scheme_2 = TransformationScheme(
        transform_type="hadamard",
        groups=[linear_args_2],
        transform_creation_args={"size": 256},
    )
    config = TransformationConfig(
        transform_groups={"transform_0": scheme_1, "transform_1": scheme_2}
    )
    # both independently-sized schemes should be retained
    assert len(config.transform_groups) == 2
