Skip to content

Commit 500af9b

Browse files
committed
use factory_kwargs
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ad29c15 commit 500af9b

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def create_transform(self, module: Module, args: TransformArgs):
5555
device = get_offloaded_device(module)
5656
exec_device = get_execution_device(module)
5757

58-
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
58+
factory_kwargs = {"construct_device": exec_device}
59+
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
5960
return HadamardTransform(weight, args)
6061

6162
def _create_weight(

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
6262
return RandomMatrixTransform(weight, args)
6363

6464
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65+
# TODO: verify that weight is invertable (has non-zero determinant)
6566
data = torch.rand(
6667
(size, size), generator=self.generator, dtype=dtype, device=device
6768
)

src/compressed_tensors/utils/helpers.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
import contextlib
1616
import warnings
1717
from functools import wraps
18-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
1919

2020
import numpy
2121
import torch
22+
from frozendict import frozendict
2223
from transformers import AutoConfig
2324

2425

@@ -373,16 +374,23 @@ class ParameterizedDefaultDict(dict):
373374

374375
def __init__(self, default_factory: Callable[[Any], Any]):
375376
self.default_factory = default_factory
376-
self._kwargs = {}
377+
self._factory_kwargs = frozendict()
377378

378379
def __missing__(self, key: Any) -> Any:
379380
if isinstance(key, tuple):
380-
value = self.default_factory(*key, **self._kwargs)
381+
value = self.default_factory(*key, **self._factory_kwargs)
381382
else:
382-
value = self.default_factory(key, **self._kwargs)
383+
value = self.default_factory(key, **self._factory_kwargs)
383384
self[key] = value
384385
return value
385386

386-
def get(self, *args, **kwargs) -> Any:
387-
with patch_attr(self, "_kwargs", kwargs):
387+
def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
388+
"""
389+
Similar to `__getitem__`, but allows passing kwargs to factory function
390+
391+
:param \\*args: args whose tuple will value will be treated as key
392+
:param factory_kwargs: keyword arguments to pass to `default_factory`
393+
:return: dictionary entry for given key
394+
"""
395+
with patch_attr(self, "_factory_kwargs", factory_kwargs):
388396
return self[args]

0 commit comments

Comments
 (0)