
Commit e7f08e1

Merge branch 'kylesayrs/transform_construct_cache_device' into kylesayrs/transform_apply
2 parents 27bc0b3 + c8f6b53 commit e7f08e1

9 files changed
+163 -139 lines changed

setup.py

Lines changed: 1 addition & 0 deletions

@@ -113,5 +113,6 @@ def _setup_extras() -> Dict:
     extras_require=_setup_extras(),
     install_requires=_setup_install_requires(),
     package_dir={"": "src"},
+    package_data={"": ["transform/utils/hadamards.safetensors"]},
     packages=_setup_packages(),
 )
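Note: the new package_data entry is what ships the Hadamard lookup table inside the wheel; the transform utilities added later in this commit resolve it relative to the installed module. A minimal sketch of that runtime lookup (mirroring the REPO_PATH definition in src/compressed_tensors/transform/utils/hadamard.py from this commit):

# assumes the package was installed with the package_data entry above
from pathlib import Path
import compressed_tensors.transform.utils.hadamard as hadamard_utils

packaged = Path(hadamard_utils.__file__).parent / "hadamards.safetensors"
print(packaged.exists())  # True only if the safetensors file was packaged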

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 3 deletions

@@ -183,9 +183,7 @@ def apply_quantization_config(
             replace_module(model, name, compressed_linear)

         # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = _scheme_from_targets(
-            target_to_scheme, targets, name
-        )
+        submodule.quantization_scheme = scheme

         names_to_scheme[name] = submodule.quantization_scheme

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the target module
         """
         # create transform as submodule
-        transform_name = f"{self.name}_{args.location}"
+        transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
         register_offload_module(module, transform_name, transform)  # (1)
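Note: interpolating the enum member's .value keeps the registered submodule name as the plain string rather than the Enum repr. A small illustration with a hypothetical stand-in enum (not the library's actual location type):

from enum import Enum

class Location(Enum):  # hypothetical stand-in, not the library's enum
    WEIGHT_INPUT = "weight_input"

name = "hadamard_0"
print(f"{name}_{Location.WEIGHT_INPUT}")        # hadamard_0_Location.WEIGHT_INPUT
print(f"{name}_{Location.WEIGHT_INPUT.value}")  # hadamard_0_weight_input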

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 13 additions & 5 deletions

@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter

@@ -55,14 +55,22 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)

-        weight = self.weights[size, dtype, device]
+        weight = self.weights.get(size, dtype, device, construct_device=exec_device)
         perm = self.perms[weight] if self.scheme.randomize else None
         return HadamardTransform(weight, perm, args)

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_permutation(self, weight: Parameter) -> Parameter:
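Note: the new construct_device argument lets the (potentially large) Hadamard weight be built on the module's execution device and only afterwards moved to the offload device where it is cached. A rough standalone sketch of that pattern (device choices here are assumptions; the factory derives them via get_execution_device / get_offloaded_device):

import torch
from torch.nn import Parameter

exec_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")

# build on the fast device (stands in for deterministic_hadamard_matrix(size, dtype, exec_device))
data = torch.randn(4096, 4096, dtype=torch.bfloat16, device=exec_device)
# then cache on the offload device, as _create_weight does
weight = Parameter(data.to(device=offload_device), requires_grad=False)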

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 10 additions & 3 deletions

@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, self.generator)
-        data = data.to(dtype=dtype, device=device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
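Note: randomization stays reproducible because the factory threads its seeded generator through to random_hadamard_matrix. A hedged sketch of the updated free-function signature (constructed on CPU here; the factory would pass the execution device instead, and the call assumes the packaged hadamards.safetensors provides a divisor compatible with size 2048):

import torch
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix

H1 = random_hadamard_matrix(2048, torch.float32, torch.device("cpu"), torch.Generator().manual_seed(0))
H2 = random_hadamard_matrix(2048, torch.float32, torch.device("cpu"), torch.Generator().manual_seed(0))
assert torch.equal(H1, H2)  # same seed -> same matrix; H1 @ H1.T is approximately the identity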

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 91 additions & 92 deletions

@@ -13,149 +13,148 @@
 # limitations under the License.

 import math
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional

-import numpy
 import torch
+from safetensors import safe_open


-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"


-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
     `n` must be a power of 2.

+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
     :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :return: hadamard matrix of size `size`
     """
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")

-    log2 = int(math.log(size, 2))
+    log2 = int(math.log2(size))
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

-    H = numpy.array([[1]], dtype=int)
+    H = torch.tensor([[1]], dtype=dtype, device=device)

     # Sylvester's construction
-    for i in range(0, log2):
-        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
-
-    return torch.from_numpy(H / math.sqrt(size))
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

-
-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+    return H / math.sqrt(size)


 def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
-    Produces a randomly generated Hadamard matrix.
-    See https://cornell-relaxml.github.io/quip-sharp/ ,
-    Section "Randomized Hadamard Transformation"
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non powers of 2
+    and randomization using a seeded generator
+
+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501

     :param size: The dimension of the hadamard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :param gen: Optional generator for random values
     :return: randomly generated hadamard matrix
     """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)


-def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
-    # NOTE: we can easily extend the list of supported shapes/sizes
-    # by adding to these methods
-    hadK, K = None, None
-    if n % 20 == 0:
-        assert _is_pow2(n // 20)
-        K = 20
-        hadK = _get_had20().T if transpose else _get_had20()
-    elif n % 12 == 0:
-        assert _is_pow2(n // 12)
-        K = 12
-        hadK = _get_had12().T if transpose else _get_had12()
-    else:
-        assert _is_pow2(n)
-        K = 1
-
-    return hadK, K
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None


-def _matmul_hadU(X, transpose=False) -> torch.Tensor:
-    n = X.shape[-1]
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
+
     # Check if we have the determined hadamard matrix
-    hadK, K = _get_hadK(n, transpose)
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
     # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
+    input = X.clone().view(-1, size, 1)
     output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
     while input.shape[1] > K:
         input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
         output = output.view(input.shape)
         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
         output = output.view(input.shape[0], input.shape[1], -1)
         (input, output) = (output, input)
+    assert input.shape[1] == K
     del output

-    # K == 1 when hadK is None; this happens when the size dim (n)
-    # is not compatible with any of the maintained hadamard matrices
-
-    if K > 1:
-        # Do not explicitly repeat - OOM
-        # input = torch.bmm(
-        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-        # Use bcast instead
-
-        # for cases when hadK is pre-determined
-        input = hadK.view(1, K, K).to(input) @ input
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input

     # normalize
     return input.view(X.shape)
-
-
-def _is_pow2(n: int) -> bool:
-    return (n & (n - 1) == 0) and (n > 0)
-
-
-def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
-    had_unpacked = numpy.unpackbits(packed_bits)
-    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
-    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
-    return had_unpacked
-
-
-# http://www.neilsloane.com/hadamard/index.html
-def _get_had12() -> torch.Tensor:
-    # fmt: off
-    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
-        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_12_unpacked = _reshape_bits(had_12, original_size=12)
-    return torch.tensor(had_12_unpacked)
-
-
-def _get_had20() -> torch.Tensor:
-    # fmt: off
-    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
-        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
-        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
-        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_20_unpacked = _reshape_bits(had_20, original_size=20)
-    return torch.tensor(had_20_unpacked)
src/compressed_tensors/transform/utils/hadamards.safetensors

Binary file not shown.
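Note: the rewritten helpers decompose a requested size n as n = K * 2^m, fetch the known K-by-K Hadamard matrix from the packaged hadamards.safetensors, and handle the remaining power-of-two factor with the butterfly loop in _matmul_hadU. A hedged usage sketch (it assumes the packaged file contains divisors compatible with the sizes below, e.g. the 12x12 and 20x20 matrices that the removed _get_had12 / _get_had20 helpers used to provide):

import torch
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    is_pow2,
    random_hadamard_matrix,
)

assert is_pow2(1024) and not is_pow2(1536)

# power-of-two size: Sylvester construction, entries are +-1/sqrt(n)
H = deterministic_hadamard_matrix(1024, torch.float32, torch.device("cpu"))

# non power-of-two size (1536 = 12 * 128): needs a known divisor from
# hadamards.safetensors; raises ValueError if none is packaged
R = random_hadamard_matrix(1536, torch.float32, torch.device("cpu"))

# both should be approximately orthogonal
print(torch.dist(H @ H.T, torch.eye(1024)))
print(torch.dist(R @ R.T, torch.eye(1536)))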

src/compressed_tensors/utils/helpers.py

Lines changed: 8 additions & 3 deletions

@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):

     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._kwargs = {}

-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, **kwargs) -> Any:
+        with patch_attr(self, "_kwargs", kwargs):
+            return self[args]
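Note: the new get method temporarily patches _kwargs so that extra keyword arguments (construct_device in the transform factories above) reach the factory only while a missing entry is being built, without becoming part of the cache key. A small hedged sketch with a made-up factory function:

from compressed_tensors.utils.helpers import ParameterizedDefaultDict

def build(size, fill, construct_scale=1):  # hypothetical factory, not part of the library
    return [fill * construct_scale] * size

cache = ParameterizedDefaultDict(build)
first = cache.get(3, 2, construct_scale=10)  # builds with the extra kwarg -> [20, 20, 20]
again = cache[3, 2]                          # cache hit keyed on (3, 2); no rebuild
assert first is again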
