 # limitations under the License.

 import math
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional

-import numpy
 import torch
+from safetensors import safe_open


-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"

-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
     `n` must be a power of 2.

+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
     :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :return: hadamard matrix of size `size`
     """
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")

-    log2 = int(math.log(size, 2))
+    log2 = int(math.log2(size))
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

-    H = numpy.array([[1]], dtype=int)
+    H = torch.tensor([[1]], dtype=dtype, device=device)

     # Sylvester's construction
-    for i in range(0, log2):
-        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
-
-    return torch.from_numpy(H / math.sqrt(size))
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

-
-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+    return H / math.sqrt(size)
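As a quick sanity check of the new signature, a minimal sketch (the `hadamard_utils` import path is a hypothetical stand-in for wherever this module lives):

```python
import torch

from hadamard_utils import deterministic_hadamard_matrix  # hypothetical path

# Sylvester construction for a power-of-2 size, in float32 for an exact check
H = deterministic_hadamard_matrix(8, dtype=torch.float32)

# rows are +-1/sqrt(8) and mutually orthogonal, so H @ H.T is the identity
assert torch.allclose(H @ H.T, torch.eye(8))
```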


 def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
-    Produces a randomly generated Hadamard matrix.
-    See https://cornell-relaxml.github.io/quip-sharp/ ,
-    Section "Randomized Hadamard Transformation"
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports sizes that are
+    not powers of 2, as well as randomization using a seeded generator
+
+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501

     :param size: The dimension of the hadamard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :param gen: Optional generator for random values
     :return: randomly generated hadamard matrix
     """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)


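A usage sketch for the new seeded path (import path again hypothetical; this assumes `hadamards.safetensors` includes a 12x12 matrix, as the removed `_get_had12` helper suggests):

```python
import torch

from hadamard_utils import random_hadamard_matrix  # hypothetical path


def make(seed: int) -> torch.Tensor:
    gen = torch.Generator().manual_seed(seed)
    # 768 = 12 * 2**6, so a stored 12x12 base block would cover it
    return random_hadamard_matrix(768, dtype=torch.float32, gen=gen)


assert torch.equal(make(0), make(0))  # same seed reproduces the same matrix
assert not torch.equal(make(0), make(1))  # different seeds flip different signs
```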
-def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
-    # NOTE: we can easily extend the list of supported shapes/sizes
-    # by adding to these methods
-    hadK, K = None, None
-    if n % 20 == 0:
-        assert _is_pow2(n // 20)
-        K = 20
-        hadK = _get_had20().T if transpose else _get_had20()
-    elif n % 12 == 0:
-        assert _is_pow2(n // 12)
-        K = 12
-        hadK = _get_had12().T if transpose else _get_had12()
-    else:
-        assert _is_pow2(n)
-        K = 1
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2

-    return hadK, K
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.

+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but the runtime cost is very small and it is
+    simpler than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
+
+
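To make the divisor lookup concrete, a small self-contained sketch of the same selection rule (`pick_divisor` is illustrative, not part of the diff; the stored sizes 12 and 20 are assumed from the removed `_get_had12`/`_get_had20` helpers):

```python
from typing import List, Optional


def is_pow2(n: int) -> bool:
    return n > 0 and (n & (n - 1) == 0)


def pick_divisor(n: int, known: List[int]) -> Optional[int]:
    # mirrors _fetch_hadamard_divisor: pick the largest stored size k
    # such that n % k == 0 and n // k is a power of two
    for k in sorted(known, reverse=True):
        if n % k == 0 and is_pow2(n // k):
            return k
    return None


assert pick_divisor(1536, [12, 20]) == 12  # 1536 = 12 * 128 = 12 * 2**7
assert pick_divisor(1280, [12, 20]) == 20  # 1280 = 20 * 64  = 20 * 2**6
```

Sorting the keys in descending order prefers the largest stored base block, which minimizes the number of butterfly passes `_matmul_hadU` has to run.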
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device

-def _matmul_hadU(X, transpose=False) -> torch.Tensor:
-    n = X.shape[-1]
     # Check if we have a predetermined hadamard matrix
-    hadK, K = _get_hadK(n, transpose)
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
     # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
+    input = X.clone().view(-1, size, 1)
     output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
     while input.shape[1] > K:
         input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
         output = output.view(input.shape)
         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
         output = output.view(input.shape[0], input.shape[1], -1)
         (input, output) = (output, input)
+    assert input.shape[1] == K
     del output

-    # K == 1 when hadK is None; this happens when the size dim (n)
-    # is not comaptible with any of the maintained hadamard matrices
-
-    if K > 1:
-        # Do not explicitly repeat - OOM
-        # input = torch.bmm(
-        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-        # Use bcast instead
-
-        # for cases when hadK is pre-determined
-        input = hadK.view(1, K, K).to(input) @ input
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input

     # normalize
     return input.view(X.shape)
-
-
-def _is_pow2(n: int) -> bool:
-    return (n & (n - 1) == 0) and (n > 0)
-
-
-def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
-    had_unpacked = numpy.unpackbits(packed_bits)
-    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
-    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
-    return had_unpacked
-
-
-# http://www.neilsloane.com/hadamard/index.html
-def _get_had12() -> torch.Tensor:
-    # fmt: off
-    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
-        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_12_unpacked = _reshape_bits(had_12, original_size=12)
-    return torch.tensor(had_12_unpacked)
-
-
-def _get_had20() -> torch.Tensor:
-    # fmt: off
-    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
-        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
-        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
-        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_20_unpacked = _reshape_bits(had_20, original_size=20)
-    return torch.tensor(had_20_unpacked)
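One note on the new `_matmul_hadU`: the `while` loop is the butterfly step of a fast Walsh-Hadamard transform, repeatedly combining block pairs as (a + b, a - b) until only the stored K x K base matrix remains to be applied. A self-contained sketch (all names illustrative, not part of the diff) showing that the butterfly with K = 1 matches an explicit Sylvester multiply:

```python
import math

import torch


def sylvester(n: int) -> torch.Tensor:
    # explicit Sylvester Hadamard matrix, n a power of 2
    H = torch.tensor([[1.0]])
    for _ in range(int(math.log2(n))):
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
    return H


def butterfly(x: torch.Tensor) -> torch.Tensor:
    # the same (a + b, a - b) passes used by _matmul_hadU, with K = 1
    x = x.view(-1, x.shape[-1], 1)
    while x.shape[1] > 1:
        x = x.view(x.shape[0], x.shape[1] // 2, 2, x.shape[2])
        x = torch.cat((x[:, :, 0] + x[:, :, 1], x[:, :, 0] - x[:, :, 1]), dim=-1)
    return x.flatten()


x = torch.randn(16)
assert torch.allclose(butterfly(x), sylvester(16) @ x, atol=1e-5)
```

This is why a size-`n` input needs only log2(n / K) passes before the single broadcast matmul against the stored base block.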