Skip to content

Commit 0a4fea5

Browse files
committed
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesayrs/transform_apply
2 parents 06e0346 + 8e36540 commit 0a4fea5

File tree

5 files changed

+18
-9
lines changed

5 files changed

+18
-9
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
8888
)
8989

9090
def _setup_install_requires() -> List:
91-
return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
91+
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
9292

9393
def _setup_extras() -> Dict:
9494
return {

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def create_transform(self, module: Module, args: TransformArgs):
5757
device = get_offloaded_device(module)
5858
exec_device = get_execution_device(module)
5959

60-
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
60+
factory_kwargs = {"construct_device": exec_device}
61+
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
6162
perm = self.perms[weight] if self.scheme.randomize else None
6263
return HadamardTransform(weight, perm, args)
6364

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
6262
return RandomMatrixTransform(weight, args)
6363

6464
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65+
# TODO: verify that weight is invertable (has non-zero determinant)
6566
data = torch.rand(
6667
(size, size), generator=self.generator, dtype=dtype, device=device
6768
)

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import torch
1615
from compressed_tensors.transform import HadamardFactory, TransformFactory
1716
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
1817
from torch import device, dtype

src/compressed_tensors/utils/helpers.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
import contextlib
1616
import warnings
1717
from functools import wraps
18-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
1919

2020
import numpy
2121
import torch
22+
from frozendict import frozendict
2223
from transformers import AutoConfig
2324

2425

@@ -373,16 +374,23 @@ class ParameterizedDefaultDict(dict):
373374

374375
def __init__(self, default_factory: Callable[[Any], Any]):
375376
self.default_factory = default_factory
376-
self._kwargs = {}
377+
self._factory_kwargs = frozendict()
377378

378379
def __missing__(self, key: Any) -> Any:
379380
if isinstance(key, tuple):
380-
value = self.default_factory(*key, **self._kwargs)
381+
value = self.default_factory(*key, **self._factory_kwargs)
381382
else:
382-
value = self.default_factory(key, **self._kwargs)
383+
value = self.default_factory(key, **self._factory_kwargs)
383384
self[key] = value
384385
return value
385386

386-
def get(self, *args, **kwargs) -> Any:
387-
with patch_attr(self, "_kwargs", kwargs):
387+
def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
388+
"""
389+
Similar to `__getitem__`, but allows passing kwargs to factory function
390+
391+
:param \\*args: args whose tuple will value will be treated as key
392+
:param factory_kwargs: keyword arguments to pass to `default_factory`
393+
:return: dictionary entry for given key
394+
"""
395+
with patch_attr(self, "_factory_kwargs", factory_kwargs):
388396
return self[args]

0 commit comments

Comments
 (0)