
Commit 893e189

Clean up implementation and add comprehensive tests
- Remove unnecessary padding from pack_bitmasks
- Add comprehensive test suite in tests/test_sparse_optimization.py
- Remove redundant comments
- Direct bit packing without intermediate operations
1 parent 5172c0c commit 893e189
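As a brief worked note on the "direct bit packing" mentioned above (the values here are only for illustration): packing is little-endian, so bit i of output byte b holds mask column 8*b + i. A mask row [1, 0, 1, 1, 0, 0, 0, 0], for example, packs into the single byte 2^0 + 2^2 + 2^3 = 13.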

3 files changed: +229, -24 lines changed

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 0 additions & 3 deletions

@@ -90,11 +90,9 @@ def from_dense(
         :return: instantiated compressed tensor
         """
         shape = list(tensor.shape)
-        # Keep tensor on its original device for faster processing
         compressed, bitmask = sparse24_bitmask_compress(
             tensor, sparsity_structure=sparsity_structure
         )
-        # Move to CPU only at the end if needed for storage
         return Sparse24BitMaskTensor(
             shape=shape,
             compressed=compressed.cpu() if compressed.is_cuda else compressed,
@@ -235,7 +233,6 @@ def get_24_bytemasks(tensor):

     reshaped_tensor = tensor.view(-1, 4)
     abs_tensor = reshaped_tensor.abs()
-    # Use largest=True, sorted=False for better performance
     topk_indices = abs_tensor.topk(2, dim=1, largest=True, sorted=False).indices
     mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
     mask.scatter_(1, topk_indices, True)
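For context, a minimal sketch (not part of this commit) of the 2:4 selection that get_24_bytemasks performs above, on a made-up 1x8 tensor; within each group of four consecutive values the two with the largest magnitude are kept:

import torch

x = torch.tensor([[0.1, -3.0, 0.5, 2.0, 1.0, -0.2, 0.3, 4.0]])  # illustrative values
groups = x.view(-1, 4)                                           # groups of 4 consecutive weights
topk = groups.abs().topk(2, dim=1, largest=True, sorted=False).indices
mask = torch.zeros_like(groups, dtype=torch.bool)
mask.scatter_(1, topk, True)                                     # True marks the 2 kept values per group
print(mask.view(x.shape))
# tensor([[False,  True, False,  True,  True, False, False,  True]])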

src/compressed_tensors/utils/helpers.py

Lines changed: 6 additions & 21 deletions
@@ -301,46 +301,31 @@ def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
     :param bytemasks: mask tensor where each byte corresponds to a weight
     :return: mask tensor where each bit corresounds to a weight
     """
-    # Try PyTorch-based implementation first to avoid CPU transfer
     try:
         device = bytemasks.device
         dtype = bytemasks.dtype

-        # Ensure input is boolean or can be treated as boolean
         if dtype != torch.bool:
             bytemasks = bytemasks.bool()

         rows, cols = bytemasks.shape
-        packed_cols = (cols + 7) // 8  # ceil(cols/8)
+        packed_cols = (cols + 7) // 8

-        # Convert boolean mask to uint8
         bytemasks_uint8 = bytemasks.to(torch.uint8)
-
-        # Pad to multiple of 8 if needed
-        if cols % 8 != 0:
-            padding = 8 - (cols % 8)
-            bytemasks_uint8 = torch.nn.functional.pad(bytemasks_uint8, (0, padding))
-
-        # Reshape to group by 8 bits
-        reshaped = bytemasks_uint8.view(rows, packed_cols, 8)
-
-        # Pack bits (little endian) - use bitwise operations
         packed = torch.zeros(rows, packed_cols, dtype=torch.uint8, device=device)
-        for i in range(8):
-            packed |= reshaped[:, :, i] << i
+
+        # Pack bits directly without padding
+        for i in range(cols):
+            packed[:, i // 8] |= bytemasks_uint8[:, i] << (i % 8)

         return packed

     except Exception:
-        # Fallback to NumPy implementation for compatibility
-        # Move to CPU if needed
         if bytemasks.is_cuda:
             bytemasks = bytemasks.cpu()

         packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-        packed_bits_torch = torch.from_numpy(packed_bits_numpy)
-
-        return packed_bits_torch
+        return torch.from_numpy(packed_bits_numpy)


 def unpack_bitmasks(
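For illustration only (not part of the commit), a standalone sketch of the direct little-endian packing loop above, checked against the numpy.packbits fallback; the shape and the mask_uint8/reference names are made up:

import numpy
import torch

mask = torch.rand(4, 13) > 0.5                        # column count deliberately not a multiple of 8
rows, cols = mask.shape
packed = torch.zeros(rows, (cols + 7) // 8, dtype=torch.uint8)
mask_uint8 = mask.to(torch.uint8)
for i in range(cols):
    packed[:, i // 8] |= mask_uint8[:, i] << (i % 8)  # set bit (i % 8) of byte (i // 8)

reference = torch.from_numpy(numpy.packbits(mask.numpy(), axis=-1, bitorder="little"))
assert torch.equal(packed, reference)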

tests/test_sparse_optimization.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
import pytest
import torch
import numpy as np
from compressed_tensors.utils.helpers import pack_bitmasks, unpack_bitmasks
from compressed_tensors.compressors.sparse_compressors.sparse_24_bitmask import (
    get_24_bytemasks,
    sparse24_bitmask_compress,
    sparse24_bitmask_decompress,
    Sparse24BitMaskTensor,
)


class TestPackBitmasks:
    """Test pack_bitmasks optimizations."""

    def test_pack_bitmasks_correctness_cpu(self):
        """Test PyTorch implementation matches NumPy on CPU."""
        test_shapes = [
            (1, 8),
            (1, 16),
            (10, 7),
            (10, 8),
            (10, 9),
            (100, 100),
            (128, 256),
            (1000, 1000),
        ]

        for shape in test_shapes:
            mask = torch.rand(shape) > 0.5

            # PyTorch implementation
            packed_torch = pack_bitmasks(mask)

            # NumPy reference
            packed_numpy = torch.from_numpy(
                np.packbits(mask.numpy(), axis=-1, bitorder="little")
            )

            assert torch.equal(packed_torch, packed_numpy), \
                f"Mismatch for shape {shape}: PyTorch != NumPy"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_pack_bitmasks_gpu(self):
        """Test GPU implementation produces correct results."""
        test_shapes = [(128, 256), (1024, 1024)]

        for shape in test_shapes:
            mask = torch.rand(shape) > 0.5
            mask_gpu = mask.cuda()

            # GPU implementation
            packed_gpu = pack_bitmasks(mask_gpu)
            assert packed_gpu.is_cuda, "Result should stay on GPU"

            # CPU reference
            packed_cpu = pack_bitmasks(mask)

            assert torch.equal(packed_gpu.cpu(), packed_cpu), \
                f"GPU result differs from CPU for shape {shape}"

    def test_pack_unpack_roundtrip(self):
        """Test pack/unpack roundtrip preserves data."""
        shapes = [(10, 16), (128, 256), (100, 999)]

        for shape in shapes:
            mask = torch.rand(shape) > 0.5
            packed = pack_bitmasks(mask)
            unpacked = unpack_bitmasks(packed, list(shape))

            assert torch.equal(mask, unpacked), \
                f"Roundtrip failed for shape {shape}"

    def test_edge_cases(self):
        """Test edge cases."""
        # Empty tensor
        empty = torch.empty(0, 0, dtype=torch.bool)
        packed = pack_bitmasks(empty)
        assert packed.shape == (0, 0)

        # Single element
        single = torch.tensor([[True]])
        packed = pack_bitmasks(single)
        assert packed.shape == (1, 1)
        assert packed[0, 0] == 1

        # All False
        all_false = torch.zeros(10, 16, dtype=torch.bool)
        packed = pack_bitmasks(all_false)
        assert torch.all(packed == 0)

        # All True
        all_true = torch.ones(10, 16, dtype=torch.bool)
        packed = pack_bitmasks(all_true)
        expected = torch.full((10, 2), 255, dtype=torch.uint8)
        assert torch.equal(packed, expected)


class TestSparse24Compression:
    """Test sparse 2:4 compression optimizations."""

    def test_compression_preserves_sparsity(self):
        """Test that compression preserves 2:4 sparsity pattern."""
        tensor = torch.randn(128, 256)

        # Get 2:4 mask
        mask = get_24_bytemasks(tensor)
        sparsity = (~mask).sum().item() / mask.numel()
        assert abs(sparsity - 0.5) < 0.01, "Should have ~50% sparsity"

        # Compress and decompress
        compressed, bitmask = sparse24_bitmask_compress(tensor)
        decompressed = sparse24_bitmask_decompress(compressed, bitmask, tensor.shape)

        # Check sparsity preserved
        decompressed_sparsity = (decompressed == 0).sum().item() / decompressed.numel()
        assert abs(decompressed_sparsity - 0.5) < 0.01, "Decompressed should maintain sparsity"

        # Check values preserved
        assert torch.allclose(tensor[mask], decompressed[mask], rtol=1e-5)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_compression(self):
        """Test compression works correctly on GPU."""
        tensor = torch.randn(256, 512).cuda()

        # Compress on GPU
        compressed_tensor = Sparse24BitMaskTensor.from_dense(tensor)

        # Check results moved to CPU for storage
        assert compressed_tensor.compressed.device.type == "cpu"
        assert compressed_tensor.bitmask.device.type == "cpu"

        # Decompress and verify
        decompressed = compressed_tensor.decompress()
        mask = get_24_bytemasks(tensor.cpu())

        assert torch.allclose(tensor.cpu()[mask], decompressed[mask], rtol=1e-5)

    def test_various_dtypes(self):
        """Test compression works with various dtypes."""
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            if dtype == torch.bfloat16 and not torch.cuda.is_available():
                continue

            tensor = torch.randn(64, 128, dtype=dtype)
            compressed_tensor = Sparse24BitMaskTensor.from_dense(tensor)
            decompressed = compressed_tensor.decompress()

            mask = get_24_bytemasks(tensor)
            assert torch.allclose(
                tensor[mask].float(),
                decompressed[mask].float(),
                rtol=1e-3 if dtype == torch.float16 else 1e-5
            )

    def test_deterministic_sparsity(self):
        """Test that sparsity pattern is deterministic."""
        tensor = torch.randn(128, 256)

        # Get mask multiple times
        mask1 = get_24_bytemasks(tensor)
        mask2 = get_24_bytemasks(tensor)
        mask3 = get_24_bytemasks(tensor)

        assert torch.equal(mask1, mask2)
        assert torch.equal(mask2, mask3)

    def test_topk_optimization(self):
        """Test that topk with sorted=False produces correct results."""
        tensor = torch.randn(128, 256)

        # Original implementation (sorted=True)
        reshaped = tensor.view(-1, 4)
        abs_vals = reshaped.abs()
        topk_sorted = abs_vals.topk(2, dim=1, largest=True, sorted=True).indices

        # Optimized implementation (sorted=False)
        topk_unsorted = abs_vals.topk(2, dim=1, largest=True, sorted=False).indices

        # Both should select the same elements (order doesn't matter)
        mask_sorted = torch.zeros_like(reshaped, dtype=torch.bool)
        mask_sorted.scatter_(1, topk_sorted, True)

        mask_unsorted = torch.zeros_like(reshaped, dtype=torch.bool)
        mask_unsorted.scatter_(1, topk_unsorted, True)

        assert torch.equal(mask_sorted, mask_unsorted)


class TestPerformance:
    """Performance regression tests."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_faster_than_cpu_transfer(self):
        """Test that GPU processing is faster than CPU transfer for large tensors."""
        import time

        tensor = torch.randn(4096, 4096).cuda()

        # Time GPU processing
        torch.cuda.synchronize()
        start = time.time()
        compressed, bitmask = sparse24_bitmask_compress(tensor)
        torch.cuda.synchronize()
        gpu_time = time.time() - start

        # Time with CPU transfer
        torch.cuda.synchronize()
        start = time.time()
        tensor_cpu = tensor.cpu()
        compressed_cpu, bitmask_cpu = sparse24_bitmask_compress(tensor_cpu)
        cpu_time = time.time() - start

        # GPU should be faster for large tensors
        assert gpu_time < cpu_time, \
            f"GPU ({gpu_time:.3f}s) should be faster than CPU transfer ({cpu_time:.3f}s)"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
