Skip to content

Commit 310fe6d

Browse files
committed
save construction device changes for later
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5a887f4 commit 310fe6d

File tree

3 files changed

+11
-32
lines changed

3 files changed

+11
-32
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
apply_transform_weight,
2323
get_matrix_size,
2424
)
25-
from compressed_tensors.utils import get_execution_device, get_offloaded_device
25+
from compressed_tensors.utils import get_offloaded_device
2626
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2727
from torch import Tensor, device, dtype
2828
from torch.nn import Linear, Module, Parameter
@@ -41,7 +41,6 @@ class HadamardFactory(TransformFactory):
4141
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
4242
super().__init__(name, scheme, seed)
4343
self.weights = ParameterizedDefaultDict(self._create_weight)
44-
self._exec_device = torch.device("cpu")
4544

4645
def create_transform(self, module: Module, args: TransformArgs):
4746
"""
@@ -55,21 +54,13 @@ def create_transform(self, module: Module, args: TransformArgs):
5554
size = get_matrix_size(module, args.location)
5655
dtype = module.weight.dtype
5756
device = get_offloaded_device(module)
58-
exec_device = get_execution_device(module)
5957

60-
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
58+
weight = self.weights[size, dtype, device]
6159
return HadamardTransform(weight, args)
6260

63-
def _create_weight(
64-
self,
65-
size: int,
66-
dtype: dtype,
67-
device: device,
68-
construct_device: device,
69-
) -> Parameter:
70-
# construct on execution device, cache on offload device
71-
data = deterministic_hadamard_matrix(size, dtype, construct_device)
72-
data = data.to(device=device)
61+
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
62+
data = deterministic_hadamard_matrix(size, dtype, device)
63+
data = data.to(dtype=dtype, device=device)
7364
return Parameter(data, requires_grad=self.scheme.requires_grad)
7465

7566

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,7 @@ class RandomHadamardFactory(HadamardFactory):
2828
:param seed: random seed used to transform weight randomization
2929
"""
3030

31-
def _create_weight(
32-
self,
33-
size: int,
34-
dtype: dtype,
35-
device: device,
36-
construct_device: device,
37-
) -> Parameter:
38-
# construct on execution device, cache on offload device
39-
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
40-
data = data.to(device=device)
31+
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
32+
data = random_hadamard_matrix(size, dtype, device, self.generator)
33+
data = data.to(dtype=dtype, device=device)
4134
return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/utils/helpers.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -373,16 +373,11 @@ class ParameterizedDefaultDict(dict):
373373

374374
def __init__(self, default_factory: Callable[[Any], Any]):
375375
self.default_factory = default_factory
376-
self._kwargs = {}
377376

378-
def __missing__(self, key: Any) -> Any:
377+
def __missing__(self, key):
379378
if isinstance(key, tuple):
380-
value = self.default_factory(*key, **self._kwargs)
379+
value = self.default_factory(*key)
381380
else:
382-
value = self.default_factory(key, **self._kwargs)
381+
value = self.default_factory(key)
383382
self[key] = value
384383
return value
385-
386-
def get(self, *args, **kwargs) -> Any:
387-
with patch_attr(self, "_kwargs", kwargs):
388-
return self[args]

0 commit comments

Comments (0)