1
+ import pytest
2
+ import torch
3
+ import numpy as np
4
+ from compressed_tensors .utils .helpers import pack_bitmasks , unpack_bitmasks
5
+ from compressed_tensors .compressors .sparse_compressors .sparse_24_bitmask import (
6
+ get_24_bytemasks ,
7
+ sparse24_bitmask_compress ,
8
+ sparse24_bitmask_decompress ,
9
+ Sparse24BitMaskTensor ,
10
+ )
11
+
12
+
13
class TestPackBitmasks:
    """Correctness tests for ``pack_bitmasks`` and ``unpack_bitmasks``.

    Random masks come from a locally seeded generator so that a failing case
    can be reproduced exactly, and shapes are parametrized so that every
    shape is reported on its own instead of the first failure hiding the
    remaining ones.
    """

    @staticmethod
    def _random_mask(shape, seed=0):
        """Return a reproducible boolean CPU mask with the given shape.

        Args:
            shape: Size of the mask to create.
            seed: Seed for the private random generator.

        Returns:
            A ``torch.bool`` tensor that is True where a uniform sample
            exceeds 0.5.
        """
        # A private generator leaves the global torch RNG state untouched.
        generator = torch.Generator().manual_seed(seed)
        return torch.rand(shape, generator=generator) > 0.5

    @pytest.mark.parametrize(
        "shape",
        [
            (1, 8),
            (1, 16),
            (10, 7),
            (10, 8),
            (10, 9),
            (100, 100),
            (128, 256),
            (1000, 1000),
        ],
    )
    def test_pack_bitmasks_correctness_cpu(self, shape):
        """Check that ``pack_bitmasks`` matches ``np.packbits`` on CPU."""
        mask = self._random_mask(shape)

        packed_torch = pack_bitmasks(mask)

        # NumPy reference: pack along the last axis with little bit order.
        packed_numpy = torch.from_numpy(
            np.packbits(mask.numpy(), axis=-1, bitorder="little")
        )

        assert torch.equal(packed_torch, packed_numpy), (
            f"Mismatch for shape {shape}: PyTorch != NumPy"
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.parametrize("shape", [(128, 256), (1024, 1024)])
    def test_pack_bitmasks_gpu(self, shape):
        """Check that packing a CUDA mask stays on GPU and matches the CPU result."""
        mask = self._random_mask(shape)
        mask_gpu = mask.cuda()

        packed_gpu = pack_bitmasks(mask_gpu)
        assert packed_gpu.is_cuda, "Result should stay on GPU"

        # The CPU result for the very same mask is the reference.
        packed_cpu = pack_bitmasks(mask)

        assert torch.equal(packed_gpu.cpu(), packed_cpu), (
            f"GPU result differs from CPU for shape {shape}"
        )

    @pytest.mark.parametrize("shape", [(10, 16), (128, 256), (100, 999)])
    def test_pack_unpack_roundtrip(self, shape):
        """Check that unpacking a packed mask returns the original mask."""
        mask = self._random_mask(shape)
        packed = pack_bitmasks(mask)
        unpacked = unpack_bitmasks(packed, list(shape))

        assert torch.equal(mask, unpacked), f"Roundtrip failed for shape {shape}"

    def test_edge_cases(self):
        """Check empty, single-element, all-False and all-True masks."""
        # Empty tensor: the packed result is expected to be empty as well.
        empty = torch.empty(0, 0, dtype=torch.bool)
        packed = pack_bitmasks(empty)
        assert packed.shape == (0, 0)

        # Single element: one True bit is expected to pack into the value 1.
        single = torch.tensor([[True]])
        packed = pack_bitmasks(single)
        assert packed.shape == (1, 1)
        assert packed[0, 0] == 1

        # All False: every packed value must be zero.
        all_false = torch.zeros(10, 16, dtype=torch.bool)
        packed = pack_bitmasks(all_false)
        assert torch.all(packed == 0)

        # All True: 16 columns are expected to pack into 2 uint8 values of 255.
        all_true = torch.ones(10, 16, dtype=torch.bool)
        packed = pack_bitmasks(all_true)
        expected = torch.full((10, 2), 255, dtype=torch.uint8)
        assert torch.equal(packed, expected)
97
+
98
+
99
class TestSparse24Compression:
    """Tests for the sparse 2:4 bitmask compression helpers.

    Inputs come from a locally seeded generator so failures are reproducible,
    and the dtype cases are parametrized so a skipped dtype is reported as
    skipped instead of silently passing.
    """

    @staticmethod
    def _seeded_randn(*size, dtype=None, seed=0):
        """Return a reproducible normally distributed CPU tensor.

        Args:
            *size: Dimensions of the tensor to create.
            dtype: Optional dtype forwarded to ``torch.randn``.
            seed: Seed for the private random generator.
        """
        # A private generator leaves the global torch RNG state untouched.
        generator = torch.Generator().manual_seed(seed)
        return torch.randn(*size, dtype=dtype, generator=generator)

    def test_compression_preserves_sparsity(self):
        """Check that a compress/decompress cycle keeps ~50% zeros and the kept values."""
        tensor = self._seeded_randn(128, 256)

        # Fraction of positions NOT selected by the 2:4 mask.
        mask = get_24_bytemasks(tensor)
        sparsity = (~mask).sum().item() / mask.numel()
        assert abs(sparsity - 0.5) < 0.01, "Should have ~50% sparsity"

        compressed, bitmask = sparse24_bitmask_compress(tensor)
        decompressed = sparse24_bitmask_decompress(compressed, bitmask, tensor.shape)

        # Fraction of exact zeros after the round trip.
        zero_count = (decompressed == 0).sum().item()
        decompressed_sparsity = zero_count / decompressed.numel()
        assert abs(decompressed_sparsity - 0.5) < 0.01, (
            "Decompressed should maintain sparsity"
        )

        # Values at the positions selected by the mask must survive.
        assert torch.allclose(tensor[mask], decompressed[mask], rtol=1e-5)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_compression(self):
        """Check compression of a CUDA tensor via ``Sparse24BitMaskTensor``."""
        tensor = self._seeded_randn(256, 512).cuda()

        compressed_tensor = Sparse24BitMaskTensor.from_dense(tensor)

        # Both stored tensors are expected to live on the CPU.
        assert compressed_tensor.compressed.device.type == "cpu"
        assert compressed_tensor.bitmask.device.type == "cpu"

        decompressed = compressed_tensor.decompress()
        mask = get_24_bytemasks(tensor.cpu())

        assert torch.allclose(tensor.cpu()[mask], decompressed[mask], rtol=1e-5)

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float32,
            torch.float16,
            # NOTE(review): the tensor below is created on the CPU, yet the
            # original loop only ran bfloat16 when CUDA was available. The
            # gate is kept as-is; confirm whether it is still required.
            pytest.param(
                torch.bfloat16,
                marks=pytest.mark.skipif(
                    not torch.cuda.is_available(),
                    reason="bfloat16 case only runs when CUDA is available",
                ),
            ),
        ],
    )
    def test_various_dtypes(self, dtype):
        """Check the compress/decompress round trip for one dtype."""
        tensor = self._seeded_randn(64, 128, dtype=dtype)
        compressed_tensor = Sparse24BitMaskTensor.from_dense(tensor)
        decompressed = compressed_tensor.decompress()

        mask = get_24_bytemasks(tensor)
        # Compare in float32; float16 gets the looser relative tolerance.
        assert torch.allclose(
            tensor[mask].float(),
            decompressed[mask].float(),
            rtol=1e-3 if dtype == torch.float16 else 1e-5,
        )

    def test_deterministic_sparsity(self):
        """Check that repeated mask computation on one tensor gives equal masks."""
        tensor = self._seeded_randn(128, 256)

        mask1 = get_24_bytemasks(tensor)
        mask2 = get_24_bytemasks(tensor)
        mask3 = get_24_bytemasks(tensor)

        assert torch.equal(mask1, mask2)
        assert torch.equal(mask2, mask3)

    def test_topk_optimization(self):
        """Check that ``topk`` selects the same elements with and without sorting."""
        tensor = self._seeded_randn(128, 256)

        # Groups of four values, ranked by magnitude.
        reshaped = tensor.view(-1, 4)
        abs_vals = reshaped.abs()

        topk_sorted = abs_vals.topk(2, dim=1, largest=True, sorted=True).indices
        topk_unsorted = abs_vals.topk(2, dim=1, largest=True, sorted=False).indices

        # Scatter into boolean masks so index order within a group is irrelevant.
        mask_sorted = torch.zeros_like(reshaped, dtype=torch.bool)
        mask_sorted.scatter_(1, topk_sorted, True)

        mask_unsorted = torch.zeros_like(reshaped, dtype=torch.bool)
        mask_unsorted.scatter_(1, topk_unsorted, True)

        assert torch.equal(mask_sorted, mask_unsorted)
191
+
192
+
193
class TestPerformance:
    """Performance regression tests."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_faster_than_cpu_transfer(self):
        """Check that compressing on GPU beats copying to CPU and compressing there.

        Both paths are warmed up once so one-time first-call costs are not
        charged to either measurement. Timing uses ``time.perf_counter`` and
        the best of a few repeats to reduce scheduling noise.
        """
        import time

        repeats = 3
        tensor = torch.randn(4096, 4096).cuda()

        def time_gpu():
            # Synchronize on both sides so only this call's GPU work is timed.
            torch.cuda.synchronize()
            start = time.perf_counter()
            sparse24_bitmask_compress(tensor)
            torch.cuda.synchronize()
            return time.perf_counter() - start

        def time_cpu_transfer():
            # The device-to-host copy is deliberately part of the measurement.
            torch.cuda.synchronize()
            start = time.perf_counter()
            tensor_cpu = tensor.cpu()
            sparse24_bitmask_compress(tensor_cpu)
            return time.perf_counter() - start

        # Warm-up runs; their timings are discarded.
        time_gpu()
        time_cpu_transfer()

        gpu_time = min(time_gpu() for _ in range(repeats))
        cpu_time = min(time_cpu_transfer() for _ in range(repeats))

        # GPU should be faster for large tensors.
        assert gpu_time < cpu_time, (
            f"GPU ({gpu_time:.3f}s) should be faster than CPU transfer ({cpu_time:.3f}s)"
        )
220
+
221
+
222
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
0 commit comments