Skip to content

Commit f45f3e9

Browse files
committed
add deterministic generation to random matrix
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 02af1e9 commit f45f3e9

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class TransformFactory(RegistryMixin, ABC):
4949
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
5050
self.name = name
5151
self.scheme = scheme
52+
self.generator = torch.Generator().manual_seed(seed)
5253
self.seed = seed
5354

5455
@classmethod

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class HadamardFactory(TransformFactory):
4040

4141
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
4242
super().__init__(name, scheme, seed)
43-
self.generator = torch.Generator(device="cpu").manual_seed(seed)
4443
self.weights = ParameterizedDefaultDict(self._create_weight)
4544

4645
def create_transform(self, module: Module, args: TransformArgs):

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def create_transform(self, module: Module, args: TransformArgs):
6060
return RandomMatrixTransform(weight, args)
6161

6262
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
63-
data = torch.rand((size, size), dtype=dtype, device=device)
63+
data = torch.rand(
64+
(size, size), generator=self.generator, dtype=dtype, device=device
65+
)
6466
return Parameter(data, requires_grad=self.scheme.requires_grad)
6567

6668
def _create_inverse(self, weight: Parameter) -> Parameter:

0 commit comments

Comments
 (0)