Skip to content

Commit b715329

Browse files
committed
construct on execution device, cache on offload device
1 parent 310fe6d commit b715329

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
apply_transform_weight,
2323
get_matrix_size,
2424
)
25-
from compressed_tensors.utils import get_offloaded_device
25+
from compressed_tensors.utils import get_execution_device, get_offloaded_device
2626
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2727
from torch import Tensor, device, dtype
2828
from torch.nn import Linear, Module, Parameter
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
4141
def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
4242
super().__init__(name, scheme, seed)
4343
self.weights = ParameterizedDefaultDict(self._create_weight)
44+
self._exec_device = torch.device("cpu")
4445

4546
def create_transform(self, module: Module, args: TransformArgs):
4647
"""
@@ -54,13 +55,21 @@ def create_transform(self, module: Module, args: TransformArgs):
5455
size = get_matrix_size(module, args.location)
5556
dtype = module.weight.dtype
5657
device = get_offloaded_device(module)
58+
exec_device = get_execution_device(module)
5759

58-
weight = self.weights[size, dtype, device]
60+
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
5961
return HadamardTransform(weight, args)
6062

61-
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
62-
data = deterministic_hadamard_matrix(size, dtype, device)
63-
data = data.to(dtype=dtype, device=device)
63+
def _create_weight(
64+
self,
65+
size: int,
66+
dtype: dtype,
67+
device: device,
68+
construct_device: device,
69+
) -> Parameter:
70+
# construct on execution device, cache on offload device
71+
data = deterministic_hadamard_matrix(size, dtype, construct_device)
72+
data = data.to(device=device)
6473
return Parameter(data, requires_grad=self.scheme.requires_grad)
6574

6675

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
2828
:param seed: random seed used to transform weight randomization
2929
"""
3030

31-
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
32-
data = random_hadamard_matrix(size, dtype, device, self.generator)
33-
data = data.to(dtype=dtype, device=device)
31+
def _create_weight(
32+
self,
33+
size: int,
34+
dtype: dtype,
35+
device: device,
36+
construct_device: device,
37+
) -> Parameter:
38+
# construct on execution device, cache on offload device
39+
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
40+
data = data.to(device=device)
3441
return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/utils/helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):
373373

374374
def __init__(self, default_factory: Callable[[Any], Any]):
375375
self.default_factory = default_factory
376+
self._kwargs = {}
376377

377-
def __missing__(self, key):
378+
def __missing__(self, key: Any) -> Any:
378379
if isinstance(key, tuple):
379-
value = self.default_factory(*key)
380+
value = self.default_factory(*key, **self._kwargs)
380381
else:
381-
value = self.default_factory(key)
382+
value = self.default_factory(key, **self._kwargs)
382383
self[key] = value
383384
return value
385+
386+
def get(self, *args, **kwargs) -> Any:
387+
with patch_attr(self, "_kwargs", kwargs):
388+
return self[args]

0 commit comments

Comments
 (0)