Skip to content

Commit 049386c

Browse files
committed
move to cuda for fp8 tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ba48863 commit 049386c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes):
4747

4848
def validate_compression(dense_matrix, decompressed_tensor):
4949
"""Validate that the decompressed tensor matches the original dense matrix."""
50-
if decompressed_tensor.device == FP8_DTYPE:
50+
if decompressed_tensor.dtype == FP8_DTYPE:
5151
decompressed_tensor = decompressed_tensor.to("cuda")
5252
dense_matrix = dense_matrix.to(decompressed_tensor.device)
5353
assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch"

0 commit comments

Comments
 (0)