Skip to content

Commit 02af1e9

Browse files
committed
use random seeds, rename matrix multiply
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 7c02bb2 commit 02af1e9

File tree

4 files changed

+6
-5
lines changed

4 files changed

+6
-5
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def apply_to_model(self, model: Module):
8282
:param model: module to apply transforms to
8383
"""
8484
for arg in self.scheme.apply:
85-
for path, module in list(model.named_modules()):
86-
if is_target(path, module, arg.targets, arg.ignore):
85+
for name, module in list(model.named_modules()):
86+
if is_target(name, module, arg.targets, arg.ignore):
8787
self._apply_to_module(module, arg)
8888

8989
def _apply_to_module(self, module: Module, args: TransformArgs):

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class HadamardFactory(TransformFactory):
4040

4141
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
4242
super().__init__(name, scheme, seed)
43+
self.generator = torch.Generator(device="cpu").manual_seed(seed)
4344
self.weights = ParameterizedDefaultDict(self._create_weight)
4445

4546
def create_transform(self, module: Module, args: TransformArgs):
@@ -59,7 +60,7 @@ def create_transform(self, module: Module, args: TransformArgs):
5960
return HadamardTransform(weight, args)
6061

6162
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
62-
data = deterministic_hadamard_matrix(size) # TODO: seed=self.seed
63+
data = deterministic_hadamard_matrix(size)
6364
data = data.to(dtype=dtype, device=device)
6465
return Parameter(data, requires_grad=self.scheme.requires_grad)
6566

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from torch.nn import Linear, Module, Parameter
2626

2727

28-
@TransformFactory.register("matrix-mul")
28+
@TransformFactory.register("random-matrix")
2929
class RandomMatrixFactory(TransformFactory):
3030
"""
3131
Factory used to apply random matrix transforms to a model

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
2929
"""
3030

3131
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
32-
data = random_hadamard_matrix(size) # seed
32+
data = random_hadamard_matrix(size, self.generator)
3333
data = data.to(dtype=dtype, device=device)
3434
return Parameter(data, requires_grad=self.scheme.requires_grad)

0 commit comments

Comments
 (0)