
Commit 8d613b3

add permutations
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 809e367 commit 8d613b3

4 files changed, +52 -36 lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 20 additions & 7 deletions
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Union
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
 from compressed_tensors.transform.utils.utils import (
+    apply_permutation,
     apply_transform_weight,
     get_matrix_size,
 )
@@ -41,6 +42,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -56,24 +58,35 @@ def create_transform(self, module: Module, args: TransformArgs):
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
-        return HadamardTransform(weight, args)
+        perm = self.perms[module, weight] if self.scheme.randomize_modules else None
+        return HadamardTransform(weight, perm, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         data = torch.tensor(deterministic_hadamard_matrix(size))  # TODO: seed=self.seed
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
+    def _create_permutation(self, module: Module, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0))
+        return Parameter(data, requires_grad=False)
+
 
 class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+    ):
         super().__init__()
         self.weight = weight
+        self.perm = perm
         self.args = args
 
     def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T / self.weight.size(0)
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = apply_permutation(weight, self.perm)
+
+        if self.args.inverse:
+            weight = weight.T / weight.size(0)
 
         return apply_transform_weight(weight, value, self.args.location)
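
The rewritten forward leans on a standard Hadamard identity: an n-by-n matrix H with +/-1 entries satisfies H @ H.T == n * I, so H.T / n is an exact inverse, which is what the args.inverse branch computes. A minimal standalone sketch of that identity, using scipy.linalg.hadamard as a stand-in for deterministic_hadamard_matrix (not shown in this diff):

import torch
from scipy.linalg import hadamard  # stand-in for deterministic_hadamard_matrix

n = 8
H = torch.tensor(hadamard(n), dtype=torch.float64)  # entries are +/-1

# H @ H.T == n * I for a +/-1 Hadamard matrix, so the inverse branch's
# expression (weight.T / weight.size(0)) recovers the identity exactly.
inverse = H.T / n
assert torch.allclose(H @ inverse, torch.eye(n, dtype=torch.float64))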

src/compressed_tensors/transform/utils/utils.py

Lines changed: 8 additions & 1 deletion
@@ -16,7 +16,7 @@
 from compressed_tensors.transform import TransformLocation
 
 
-__all__ = ["get_matrix_size", "apply_transform_weight"]
+__all__ = ["get_matrix_size", "apply_transform_weight", "apply_permutation"]
 
 
 def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
@@ -83,3 +83,10 @@ def apply_transform_weight(
 
     elif location == TransformLocation.OUTPUT:
         return value @ weight
+
+
+def apply_permutation(weight: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
+    weight = weight.clone()
+    diag_indices = torch.arange(weight.size(0))
+    weight[diag_indices, diag_indices] = weight.diagonal()[perm]
+    return weight
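
Note what the new helper does and does not touch: apply_permutation shuffles only the diagonal entries of the weight and leaves every off-diagonal entry in place. A self-contained illustration (the helper is re-declared verbatim so the snippet runs on its own):

import torch

def apply_permutation(weight: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    weight = weight.clone()
    diag_indices = torch.arange(weight.size(0))
    weight[diag_indices, diag_indices] = weight.diagonal()[perm]
    return weight

w = torch.arange(16, dtype=torch.float32).reshape(4, 4)
perm = torch.tensor([2, 0, 3, 1])
out = apply_permutation(w, perm)

# the diagonal is permuted ...
assert torch.equal(out.diagonal(), w.diagonal()[perm])
# ... while every off-diagonal entry is unchanged
assert torch.equal(out - torch.diag(out.diagonal()), w - torch.diag(w.diagonal()))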

tests/test_transform/factory/test_correctness.py

Lines changed: 12 additions & 13 deletions
@@ -19,10 +19,18 @@
     TransformFactory,
     TransformScheme,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import force_cpu_offload
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
+_test_schemes = [
+    TransformScheme(type=name) for name in TransformFactory.registered_names()
+] + [
+    TransformScheme(type=name, randomize_modules=True)
+    for name in TransformFactory.registered_names()
+]
+
+
 class TransformableModel(torch.nn.Module):
     def __init__(self, *sizes):
         super().__init__()
@@ -37,10 +45,7 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_correctness_linear(scheme):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
@@ -68,10 +73,7 @@ def test_correctness_linear(scheme):
     torch.allclose(true_output, output, atol=1e-7, rtol=0.0)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_correctness_model(scheme, offload=False):
     # load model
     model = TransformableModel(2, 4, 8, 16)
@@ -99,9 +101,6 @@ def test_correctness_model(scheme, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_correctness_model_offload(scheme):
     test_correctness_model(scheme, offload=True)
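
The shared _test_schemes list doubles the parametrized matrix: each registered transform type now runs once with default settings and once with randomize_modules=True. A hypothetical standalone sketch of the same pattern (the names below are placeholders, not the actual registry contents):

registered_names = ["hadamard", "random-hadamard"]  # placeholder names for illustration

_test_schemes = [
    {"type": name} for name in registered_names
] + [
    {"type": name, "randomize_modules": True} for name in registered_names
]

# one default case plus one randomized case per registered type
assert len(_test_schemes) == 2 * len(registered_names)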

tests/test_transform/factory/test_memory.py

Lines changed: 12 additions & 15 deletions
@@ -26,6 +26,14 @@
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
+_test_schemes = [
+    TransformScheme(type=name) for name in TransformFactory.registered_names()
+] + [
+    TransformScheme(type=name, randomize_modules=True)
+    for name in TransformFactory.registered_names()
+]
+
+
 class TransformableModel(torch.nn.Module):
     def __init__(self, *sizes):
         super().__init__()
@@ -40,10 +48,7 @@ def forward(self, x):
         return x
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_memory_sharing(scheme, offload=False):
     # load scheme and factory
     scheme = TransformScheme(
@@ -93,20 +98,12 @@ def test_memory_sharing(scheme, offload=False):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_memory_sharing_offload(scheme):
     test_memory_sharing(scheme, offload=True)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
+@pytest.mark.parametrize("scheme", _test_schemes)
 def test_memory_sharing_training(scheme):
+    scheme.requires_grad = True
     test_memory_sharing(scheme, offload=False)
