Skip to content

Commit f7e078f

Browse files
authored
[Transform] Hadamard Permutations (#329)
* add utilities Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add additional tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add utils and tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Implement transform factories Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add permutations Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add delete_offload_module Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * key inverses by weight Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * standardize random hadamard Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * prepend input hooks Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * apply sqrt division first Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use divided hadamards Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix typo Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add random option Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use random seeds, rename matrix multiply Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add deterministic generation to random matrix Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix perm math Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * update docstrings Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * update docstrings Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * cleanup Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * cleanup 2 Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * make seed optional Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove iterable check and missing return value Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Remove unrelated changes * simplify code Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix style Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 98a0cd7 commit f7e078f

File tree

4 files changed

+28
-17
lines changed

4 files changed

+28
-17
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from typing import Optional, Union
1616

1717
import torch
1818
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
4141
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
4242
super().__init__(name, scheme, seed)
4343
self.weights = ParameterizedDefaultDict(self._create_weight)
44+
self.perms = ParameterizedDefaultDict(self._create_permutation)
4445

4546
def create_transform(self, module: Module, args: TransformArgs):
4647
"""
@@ -56,24 +57,35 @@ def create_transform(self, module: Module, args: TransformArgs):
5657
device = get_offloaded_device(module)
5758

5859
weight = self.weights[size, dtype, device]
59-
return HadamardTransform(weight, args)
60+
perm = self.perms[weight] if self.scheme.randomize else None
61+
return HadamardTransform(weight, perm, args)
6062

6163
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6264
data = deterministic_hadamard_matrix(size, dtype, device)
6365
data = data.to(dtype=dtype, device=device)
6466
return Parameter(data, requires_grad=self.scheme.requires_grad)
6567

68+
def _create_permutation(self, weight: Parameter) -> Parameter:
69+
data = torch.randperm(weight.size(0), generator=self.generator)
70+
return Parameter(data, requires_grad=False)
71+
6672

6773
class HadamardTransform(TransformBase):
68-
def __init__(self, weight: Parameter, args: TransformArgs):
74+
def __init__(
75+
self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
76+
):
6977
super().__init__()
7078
self.weight = weight
79+
self.perm = perm
7180
self.args = args
7281

7382
def forward(self, value: Tensor) -> Tensor:
74-
if not self.args.inverse:
75-
weight = self.weight
76-
else:
77-
weight = self.weight.T
83+
weight = self.weight
84+
85+
if self.perm is not None:
86+
weight = weight[self.perm][:, self.perm]
87+
88+
if self.args.inverse:
89+
weight = weight.T
7890

7991
return apply_transform_weight(weight, value, self.args.location)

src/compressed_tensors/transform/transform_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
4949
inverse=True,
5050
),
5151
],
52-
randomize_modules=True,
52+
randomize=True,
5353
),
5454
"u": TransformScheme(
5555
type="hadamard",
@@ -62,7 +62,7 @@ class TransformConfig(BaseModel):
6262
targets=["Linear"], location="output", inverse=True # non-mergable
6363
),
6464
],
65-
randomize_modules=True,
65+
randomize=True,
6666
),
6767
}
6868
)

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
3131
(see `Transforms.registered_names()`)
3232
:param apply: list of TransformationArgs containing the information about the
3333
modules that should be targeted by the specified transform
34-
:param randomize_modules: True if unique transforms should be applied to each
35-
unique module targeted by `apply`, otherwise reuse transform weights where
36-
applicable
34+
:param randomize: True if uniquely randomized transform weights should be used,
35+
otherwise use identical transform weights where applicable
3736
:param requires_grad: True if weights include gradients for training
3837
"""
3938

4039
type: str
4140
apply: List[TransformArgs] = Field(default_factory=list)
42-
randomize_modules: bool = Field(default=False)
41+
randomize: bool = Field(default=False)
4342
requires_grad: bool = Field(default=False)

tests/test_transform/test_transform_scheme.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_basic_scheme():
2424
type="hadamard",
2525
apply=[basic_args],
2626
)
27-
assert not scheme.randomize_modules
27+
assert not scheme.randomize
2828
assert scheme.type == "hadamard"
2929
assert len(scheme.apply) == 1
3030
assert isinstance(scheme.apply[0], TransformArgs)
@@ -43,10 +43,10 @@ def test_multiple_groups_global():
4343
scheme = TransformScheme(
4444
type="hadamard",
4545
apply=[embedding_args, linear_args],
46-
randomize_modules=True,
46+
randomize=True,
4747
)
4848

49-
assert scheme.randomize_modules
49+
assert scheme.randomize
5050
assert scheme.type == "hadamard"
5151
assert len(scheme.apply) == 2
5252
assert isinstance(scheme.apply[0], TransformArgs)
@@ -69,6 +69,6 @@ def test_multiple_groups():
6969
apply=apply,
7070
)
7171

72-
assert not scheme.randomize_modules
72+
assert not scheme.randomize
7373
assert scheme.type == "hadamard"
7474
assert len(scheme.apply) == 20

0 commit comments

Comments
 (0)