Skip to content

Commit 226f367

Browse files
committed
merge with construct: construct in float32
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 955f2f5 commit 226f367

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def _create_weight(
6969
construct_device: device,
7070
) -> Parameter:
7171
# construct on execution device, cache on offload device
72-
data = deterministic_hadamard_matrix(size, dtype, construct_device)
73-
data = data.to(device=device)
72+
data = deterministic_hadamard_matrix(size, torch.float32, construct_device)
73+
data = data.to(dtype=dtype, device=device)
7474
return Parameter(data, requires_grad=self.scheme.requires_grad)
7575

7676
def _create_permutation(self, weight: Parameter) -> Parameter:

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import torch
1516
from compressed_tensors.transform import HadamardFactory, TransformFactory
1617
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
1718
from torch import device, dtype
@@ -36,6 +37,8 @@ def _create_weight(
3637
construct_device: device,
3738
) -> Parameter:
3839
# construct on execution device, cache on offload device
39-
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
40-
data = data.to(device=device)
40+
data = random_hadamard_matrix(
41+
size, torch.float32, construct_device, self.generator
42+
)
43+
data = data.to(dtype=dtype, device=device)
4144
return Parameter(data, requires_grad=self.scheme.requires_grad)

0 commit comments

Comments
 (0)