31
31
32
32
33
33
def deterministic_hadamard_matrix (
34
- size : int , dtype : torch .dtype = torch .bfloat16
34
+ size : int ,
35
+ dtype : torch .dtype = torch .bfloat16 ,
36
+ device : torch .device = torch .device ("cpu" ),
35
37
) -> torch .Tensor :
36
38
"""
37
39
Construct an n-by-n Hadamard matrix, using Sylvester's construction.
@@ -49,7 +51,7 @@ def deterministic_hadamard_matrix(
49
51
if size != 2 ** log2 :
50
52
raise ValueError ("Cannot construct deterministic hadamard of size != 2^n" )
51
53
52
- H = torch .tensor ([[1 ]], dtype = dtype )
54
+ H = torch .tensor ([[1 ]], dtype = dtype , device = device )
53
55
54
56
# Sylvester's construction
55
57
for _ in range (0 , log2 ):
@@ -61,6 +63,7 @@ def deterministic_hadamard_matrix(
61
63
def random_hadamard_matrix (
62
64
size : int ,
63
65
dtype : torch .dtype = torch .bfloat16 ,
66
+ device : torch .device = torch .device ("cpu" ),
64
67
gen : Optional [torch .Generator ] = None ,
65
68
) -> torch .Tensor :
66
69
"""
@@ -75,7 +78,9 @@ def random_hadamard_matrix(
75
78
:return: randomly generated hadamard matrix
76
79
"""
77
80
# Benefits: support other shapes / non powers of 2, support randomization
78
- Q = torch .randint (low = 0 , high = 2 , size = (size ,), generator = gen , dtype = dtype )
81
+ Q = torch .randint (
82
+ low = 0 , high = 2 , size = (size ,), generator = gen , dtype = dtype , device = device
83
+ )
79
84
Q = Q * 2 - 1
80
85
Q = torch .diag (Q )
81
86
return _matmul_hadU (Q ) / math .sqrt (size )
@@ -86,16 +91,18 @@ def is_pow2(n: int) -> bool:
86
91
87
92
88
93
def _get_known_divisor (
89
- n : int , dtype : torch .dtype , file_path : str = REPO_PATH
94
+ n : int ,
95
+ dtype : torch .dtype ,
96
+ device : torch .device = torch .device ("cpu" ),
97
+ file_path : str = REPO_PATH ,
90
98
) -> Optional [torch .Tensor ]:
91
99
"""
92
100
Fetch a known hadamard matrix from the given file path. The returned matrix will
93
101
be of size `k` such that `n / k` is a power of two. Return None if no such
94
102
matrix exists.
95
103
96
104
Note: This function reopens the safetensors file every time it is called.
97
- This is inefficient, but inconsequential because hadamards are typically
98
- cached by size through the factory that produced them. This is also simpler
105
+ This is technically inefficient, but a very small runtime cost and simpler
99
106
than forcing callers to manage the file open context
100
107
101
108
:param n: size of known hadamard matrix
@@ -105,17 +112,18 @@ def _get_known_divisor(
105
112
divisors = sorted ([int (key ) for key in file .keys ()], reverse = True )
106
113
for divisor in divisors :
107
114
if n % divisor == 0 and is_pow2 (n // divisor ):
108
- return file .get_tensor (str (divisor )).to (dtype = dtype )
115
+ return file .get_tensor (str (divisor )).to (dtype = dtype , device = device )
109
116
110
117
return None
111
118
112
119
113
120
def _matmul_hadU (X : torch .Tensor ) -> torch .Tensor :
114
- size = X .shape [ - 1 ]
121
+ size = X .size ( 0 )
115
122
dtype = X .dtype
123
+ device = X .device
116
124
117
125
# Check if we have the determined hadamard matrix
118
- hadK = _get_known_divisor (size , dtype )
126
+ hadK = _get_known_divisor (size , dtype , device = device )
119
127
if hadK is None :
120
128
raise ValueError (f"Cannot construct random hadamard matrix of size { size } " )
121
129
K = hadK .size (0 )
@@ -130,6 +138,7 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
130
138
output [:, :, 1 , :] = input [:, :, 0 , :] - input [:, :, 1 , :]
131
139
output = output .view (input .shape [0 ], input .shape [1 ], - 1 )
132
140
(input , output ) = (output , input )
141
+ assert input .shape [1 ] == K
133
142
del output
134
143
135
144
# Do not explicitly repeat - OOM
0 commit comments