
Commit 6c63987

[Transform] Construct on GPU, cache on CPU (#352)
* use hadamards database file
* try manifest
* try setup, update hadamards list
* fix setup
* add docstrings, cleanup
* fix setup, thank you @dbarbuzzi
* remove numpy, add tests
* solidify dtype, add gpu tests
* fix docstring
* add device option
* construct on execution device, cache on offload device
* save construction device changes for later
* construct on execution device, cache on offload device
* cite nja sloane
* remove dreg
* put on device via safe_open
* nits and docstrings
* update docstring
* construct with same dtype, constructing on fp32 found no difference
* remove unnecessary imports
* use factory_kwargs
* add frozen dict to deps
* correct typo
* fix missing import

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f7e078f commit 6c63987

File tree: 5 files changed, +43 −13 lines

setup.py

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )

 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]

 def _setup_extras() -> Dict:
     return {

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 14 additions & 5 deletions

@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter
@@ -55,14 +55,23 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)

-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"construct_device": exec_device}
+        weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
         return HadamardTransform(weight, perm, args)

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype, device)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_permutation(self, weight: Parameter) -> Parameter:
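The core idea of this change is to build each transform weight on the module's execution device (typically a GPU) and then move the finished tensor to the offload device (typically CPU) where it is cached. A minimal, self-contained sketch of that pattern follows; the name create_cached_weight and the use of torch.eye as a stand-in for deterministic_hadamard_matrix are illustrative only, not part of this commit:

import torch
from torch.nn import Parameter

def create_cached_weight(
    size: int,
    dtype: torch.dtype,
    offload_device: torch.device,
    construct_device: torch.device,
) -> Parameter:
    # build the matrix on the fast execution device (e.g. "cuda:0") ...
    data = torch.eye(size, dtype=dtype, device=construct_device)
    # ... then move it to the offload device (e.g. "cpu") for caching,
    # so the constructed tensor does not stay resident in GPU memory
    data = data.to(device=offload_device)
    return Parameter(data, requires_grad=False)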

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return RandomMatrixTransform(weight, args)

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        # TODO: verify that weight is invertible (has non-zero determinant)
         data = torch.rand(
             (size, size), generator=self.generator, dtype=dtype, device=device
         )
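The TODO above is not resolved in this commit. One hypothetical way to address it would be a determinant check on the randomly generated matrix; the helper below is only a sketch of that idea and is not part of the repository:

import torch

def assert_invertible(weight: torch.Tensor, atol: float = 1e-6) -> None:
    # a square matrix is invertible iff its determinant is non-zero;
    # compute in float32 since torch.linalg.det needs a floating-point input
    det = torch.linalg.det(weight.float())
    if torch.abs(det) <= atol:
        raise ValueError(f"weight is not invertible (det={det.item():.3e})")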

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 10 additions & 3 deletions

@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype, device, self.generator)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/utils/helpers.py

Lines changed: 17 additions & 4 deletions

@@ -15,10 +15,11 @@
 import contextlib
 import warnings
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional

 import numpy
 import torch
+from frozendict import frozendict
 from transformers import AutoConfig


@@ -373,11 +374,23 @@ class ParameterizedDefaultDict(dict):

     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._factory_kwargs = frozendict()

-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._factory_kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._factory_kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, factory_kwargs: Mapping = frozendict()) -> Any:
+        """
+        Similar to `__getitem__`, but allows passing kwargs to the factory function
+
+        :param \*args: args whose tuple will be treated as the key
+        :param factory_kwargs: keyword arguments to pass to `default_factory`
+        :return: dictionary entry for the given key
+        """
+        with patch_attr(self, "_factory_kwargs", factory_kwargs):
+            return self[args]
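For context, the new get() method lets a caller pass construction-only keyword arguments (such as construct_device) that are forwarded to the factory but are not part of the cache key. A usage sketch follows, where make_weight is a hypothetical factory standing in for HadamardFactory._create_weight:

import torch
from compressed_tensors.utils.helpers import ParameterizedDefaultDict

def make_weight(size, dtype, device, construct_device=None):
    # key elements arrive positionally; factory_kwargs arrive as keyword arguments
    src = construct_device if construct_device is not None else device
    return torch.zeros(size, size, dtype=dtype, device=src).to(device)

weights = ParameterizedDefaultDict(make_weight)

# plain indexing calls the factory with the key elements only
w_cpu = weights[64, torch.float16, "cpu"]

# get() additionally forwards factory_kwargs to the factory for this lookup;
# the kwargs are not part of the key, so repeated lookups reuse the cached entry
w_big = weights.get(128, torch.float16, "cpu",
                    factory_kwargs={"construct_device": "cuda:0"})  # needs a GPU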
