Skip to content

Commit 3c55003

Browse files
committed
Merge branch 'kylesayrs/transform_construct_cache_device' into kylesayrs/transform_apply
2 parents 9745acb + fd3390a commit 3c55003

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def _create_weight(
6969
construct_device: device,
7070
) -> Parameter:
7171
# construct on execution device, cache on offload device
72-
data = deterministic_hadamard_matrix(size, torch.float32, construct_device)
73-
data = data.to(dtype=dtype, device=device)
72+
data = deterministic_hadamard_matrix(size, dtype, construct_device)
73+
data = data.to(device=device)
7474
return Parameter(data, requires_grad=self.scheme.requires_grad)
7575

7676
def _create_permutation(self, weight: Parameter) -> Parameter:

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ def _create_weight(
3737
construct_device: device,
3838
) -> Parameter:
3939
# construct on execution device, cache on offload device
40-
data = random_hadamard_matrix(
41-
size, torch.float32, construct_device, self.generator
42-
)
43-
data = data.to(dtype=dtype, device=device)
40+
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
41+
data = data.to(device=device)
4442
return Parameter(data, requires_grad=self.scheme.requires_grad)

0 commit comments

Comments
 (0)