Skip to content

Commit dd72b6a

Browse files
committed
make seed optional
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fbd2939 commit dd72b6a

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from abc import ABC, abstractmethod
16+
from typing import Optional
1617

1718
import torch
1819
import torch.nn.utils.parametrize as P
@@ -46,11 +47,12 @@ class TransformFactory(RegistryMixin, ABC):
4647
:param seed: random seed used to transform weight randomization
4748
"""
4849

49-
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
50+
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
5051
self.name = name
5152
self.scheme = scheme
52-
self.generator = torch.Generator().manual_seed(seed)
53-
self.seed = seed
53+
self.generator = torch.Generator()
54+
if seed is not None:
55+
self.generator.manual_seed(seed)
5456

5557
@classmethod
5658
def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class HadamardFactory(TransformFactory):
3838
:param seed: random seed used to transform weight randomization
3939
"""
4040

41-
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
41+
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
4242
super().__init__(name, scheme, seed)
4343
self.weights = ParameterizedDefaultDict(self._create_weight)
4444

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
16+
1517
import torch
1618
from compressed_tensors.transform import TransformArgs, TransformScheme
1719
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -35,7 +37,7 @@ class RandomMatrixFactory(TransformFactory):
3537
:param seed: random seed used to transform weight randomization
3638
"""
3739

38-
def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
40+
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
3941
super().__init__(name, scheme, seed)
4042
self.weights = ParameterizedDefaultDict(self._create_weight)
4143
self.inverses = ParameterizedDefaultDict(self._create_inverse)

0 commit comments

Comments
 (0)