
Commit 5a887f4

construct on execution device, cache on offload device
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fbaf47a

File tree: 5 files changed, +37 -15 lines

src/compressed_tensors/transform/factory/hadamard.py (13 additions, 4 deletions)

@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self._exec_device = torch.device("cpu")

     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -54,12 +55,20 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)

-        weight = self.weights[size, dtype, device]
+        weight = self.weights.get(size, dtype, device, construct_device=exec_device)
         return HadamardTransform(weight, args)

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype=dtype)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, dtype, construct_device)
         data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
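The factory now builds the matrix once on the fast execution device, then moves it to the offload device where it is cached; RandomHadamardFactory below follows the same shape. A minimal standalone sketch of the pattern, with torch.eye standing in for the Hadamard construction and build_and_offload as a hypothetical helper name:

import torch
from torch.nn import Parameter

def build_and_offload(
    size: int,
    dtype: torch.dtype,
    device: torch.device,  # offload device, where the weight is cached
    construct_device: torch.device,  # execution device, where it is built
) -> Parameter:
    # Build on the (typically faster) execution device...
    data = torch.eye(size, dtype=dtype, device=construct_device)
    # ...then move to the offload device so it does not sit in accelerator memory
    data = data.to(device=device)
    return Parameter(data, requires_grad=False)

exec_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight = build_and_offload(64, torch.float32, torch.device("cpu"), exec_device)
print(weight.device)  # cpu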

src/compressed_tensors/transform/factory/random_hadamard.py (9 additions, 2 deletions)

@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype=dtype, gen=self.generator)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
         data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/transform/utils/hadamard.py (5 additions, 4 deletions)

@@ -71,16 +71,17 @@ def random_hadamard_matrix(
     See https://cornell-relaxml.github.io/quip-sharp/ ,
     Section "Randomized Hadamard Transformation"

+    Improves upon deterministic_hadamard_matrix
+    in that this supports non powers of 2 and random seeds
+
     Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501

     :param size: The dimension of the hadamard matrix
     :param gen: Optional generator for random values
     :return: randomly generated hadamard matrix
     """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(
-        low=0, high=2, size=(size,), generator=gen, dtype=dtype, device=device
-    )
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q) / math.sqrt(size)
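The randint call is deliberately kept on the CPU: gen is typically a CPU torch.Generator, which cannot be handed to a sampling call that targets a CUDA device, and sampling on CPU also makes a given seed yield the same signs regardless of the final device. A small sketch of the same sample-then-move step:

import torch

# Seeded sampling happens on the CPU generator; moving afterwards keeps the
# result identical no matter which device the matrix ultimately lives on.
gen = torch.Generator().manual_seed(42)
Q = torch.randint(low=0, high=2, size=(8,), generator=gen, dtype=torch.float32)

# Moving after sampling sidesteps the generator/device mismatch that
# passing device="cuda" directly would cause with a CPU generator.
device = "cuda" if torch.cuda.is_available() else "cpu"
Q = Q.to(device=device)
print(Q * 2 - 1)  # random signs in {-1, +1}, reproducible per seed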

src/compressed_tensors/utils/helpers.py (8 additions, 3 deletions)

@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):

     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._kwargs = {}

-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, **kwargs) -> Any:
+        with patch_attr(self, "_kwargs", kwargs):
+            return self[args]
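ParameterizedDefaultDict.get threads keyword arguments to the factory through a temporarily patched attribute, while only the positional arguments form the cache key; construct_device therefore influences only the first construction of each (size, dtype, device) entry. Note also that this get shadows dict.get's usual (key, default) signature. A self-contained sketch of the mechanism, using a stand-in for compressed_tensors' patch_attr whose exact signature is assumed here:

from contextlib import contextmanager
from typing import Any, Callable

@contextmanager
def patch_attr(obj: Any, name: str, value: Any):
    # Stand-in for compressed_tensors' patch_attr: temporarily set an
    # attribute, restoring the previous value on exit.
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, old)

class ParameterizedDefaultDict(dict):
    def __init__(self, default_factory: Callable[..., Any]):
        self.default_factory = default_factory
        self._kwargs = {}

    def __missing__(self, key: Any) -> Any:
        # kwargs reach the factory on a cache miss but are not part of the key
        if isinstance(key, tuple):
            value = self.default_factory(*key, **self._kwargs)
        else:
            value = self.default_factory(key, **self._kwargs)
        self[key] = value
        return value

    def get(self, *args: Any, **kwargs: Any) -> Any:
        with patch_attr(self, "_kwargs", kwargs):
            return self[args]

cache = ParameterizedDefaultDict(lambda a, b, scale=1: (a + b) * scale)
print(cache.get(2, 3, scale=10))  # 50: factory runs with scale=10
print(cache.get(2, 3, scale=99))  # still 50: key (2, 3) is already cached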

tests/test_transform/utils/test_hadamard.py (2 additions, 2 deletions)

@@ -44,8 +44,8 @@
 @pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
     # (H / sqrt(n))(H.T / sqrt(n)) == I
-    had_matrix = random_hadamard_matrix(size, device="cuda")
-    product = had_matrix @ had_matrix.T
+    matrix = random_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
     eye = torch.eye(size, dtype=product.dtype, device="cuda")
     assert torch.allclose(product, eye, atol=_atol)
