
Commit 277d325

rahul-tuli and claude committed
Refactor sparse optimization code following DRY and single responsibility principles
- Split pack_bitmasks into modular functions with single responsibilities:
  - _validate_bitmask_shape(): Input validation
  - _pack_bits_torch(): Core PyTorch packing logic
  - _pack_bits_numpy_fallback(): NumPy fallback
- Refactored get_24_bytemasks with helper functions:
  - _validate_24_sparsity_tensor(): Tensor validation
  - _get_topk_mask(): Isolated mask generation algorithm
- Improved error messages with actual tensor dimensions
- Reduced test suite from 222 to 182 lines by removing redundancy
- Organized tests into focused classes by functionality
- All optimizations preserved; code is now more maintainable

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 893e189 commit 277d325


3 files changed: +182 -142 lines changed


src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 47 additions & 13 deletions
@@ -206,7 +206,38 @@ def sparse24_bitmask_decompress(
     return decompressed_tensor
 
 
-def get_24_bytemasks(tensor):
+def _validate_24_sparsity_tensor(tensor: torch.Tensor) -> None:
+    """
+    Validate that tensor is suitable for 2:4 sparsity.
+
+    :param tensor: Input tensor to validate
+    :raises ValueError: If tensor size is not a multiple of 4
+    """
+    if tensor.numel() % 4 != 0:
+        raise ValueError(
+            f"Tensor size must be a multiple of 4 for 2:4 sparsity, "
+            f"got {tensor.numel()} elements"
+        )
+
+
+def _get_topk_mask(reshaped_tensor: torch.Tensor, k: int = 2) -> torch.Tensor:
+    """
+    Get mask for top-k elements per group based on absolute values.
+
+    :param reshaped_tensor: Tensor reshaped into groups
+    :param k: Number of elements to keep per group
+    :return: Boolean mask tensor
+    """
+    abs_tensor = reshaped_tensor.abs()
+    # sorted=False provides performance improvement without affecting correctness
+    topk_indices = abs_tensor.topk(k, dim=1, largest=True, sorted=False).indices
+
+    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
+    mask.scatter_(1, topk_indices, True)
+    return mask
+
+
+def get_24_bytemasks(tensor: torch.Tensor) -> torch.Tensor:
     """
     Generate a 2:4 sparsity mask for the given tensor.
 
@@ -222,22 +253,25 @@ def get_24_bytemasks(tensor):
     :raises ValueError: If the total number of elements in the tensor is not a
         multiple of 4.
     """
+    # Validate input
+    _validate_24_sparsity_tensor(tensor)
+
     original_dtype = tensor.dtype
+    original_shape = tensor.shape
+
+    # Handle FP8 dtype by viewing as int8 for magnitude comparison
     if tensor.dtype == FP8_DTYPE:
         tensor = tensor.view(torch.int8)
-    original_shape = tensor.shape
-    num_elements = tensor.numel()
-
-    if num_elements % 4 != 0:
-        raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
-
+
+    # Reshape into groups of 4 and get top-2 mask
     reshaped_tensor = tensor.view(-1, 4)
-    abs_tensor = reshaped_tensor.abs()
-    topk_indices = abs_tensor.topk(2, dim=1, largest=True, sorted=False).indices
-    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
-    mask.scatter_(1, topk_indices, True)
+    mask = _get_topk_mask(reshaped_tensor, k=2)
+
+    # Restore original shape
     mask = mask.view(original_shape)
-    if tensor.dtype == torch.int8:
+
+    # Restore tensor dtype if it was changed
+    if tensor.dtype == torch.int8 and original_dtype == FP8_DTYPE:
         tensor = tensor.view(original_dtype)
-
+
     return mask
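
For context, a minimal standalone sketch (not part of this commit) of the behavior get_24_bytemasks preserves: within every contiguous group of four elements, exactly the two largest-magnitude entries survive. The function name and test values below are hypothetical, and the FP8 view-as-int8 branch is omitted.

import torch

def get_24_bytemasks_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical equivalent for plain float tensors
    if tensor.numel() % 4 != 0:
        raise ValueError(f"Tensor size must be a multiple of 4, got {tensor.numel()}")
    groups = tensor.view(-1, 4)
    # Top-2 magnitudes per group; sorted=False skips an unnecessary sort
    idx = groups.abs().topk(2, dim=1, largest=True, sorted=False).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    return mask.view(tensor.shape)

t = torch.tensor([[0.1, -3.0, 0.2, 4.0],
                  [5.0, 0.0, -6.0, 0.5]])
mask = get_24_bytemasks_sketch(t)
assert (mask.view(-1, 4).sum(dim=1) == 2).all()  # exactly 2 of every 4 kept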

src/compressed_tensors/utils/helpers.py

Lines changed: 62 additions & 16 deletions
@@ -293,39 +293,85 @@ def combine_shards(shards, dim=0):
     return combined
 
 
+def _validate_bitmask_shape(bytemasks: torch.Tensor) -> None:
+    """
+    Validates input tensor shape for bitmask packing.
+
+    :param bytemasks: Input tensor to validate
+    :raises ValueError: If tensor is not 2D
+    """
+    if len(bytemasks.shape) != 2:
+        raise ValueError(
+            f"pack_bitmasks expects a 2D tensor, got shape {bytemasks.shape}"
+        )
+
+
+def _pack_bits_torch(bytemasks_uint8: torch.Tensor, rows: int, cols: int,
+                     device: torch.device) -> torch.Tensor:
+    """
+    Pack bits using PyTorch operations.
+
+    :param bytemasks_uint8: Boolean mask converted to uint8
+    :param rows: Number of rows in the mask
+    :param cols: Number of columns in the mask
+    :param device: Device to create the packed tensor on
+    :return: Packed bitmask tensor
+    """
+    packed_cols = (cols + 7) // 8
+    packed = torch.zeros(rows, packed_cols, dtype=torch.uint8, device=device)
+
+    # Pack bits directly without padding
+    for i in range(cols):
+        packed[:, i // 8] |= bytemasks_uint8[:, i] << (i % 8)
+
+    return packed
+
+
+def _pack_bits_numpy_fallback(bytemasks: torch.Tensor) -> torch.Tensor:
+    """
+    Fallback to NumPy implementation for compatibility.
+
+    :param bytemasks: Input boolean mask tensor
+    :return: Packed bitmask tensor
+    """
+    if bytemasks.is_cuda:
+        bytemasks = bytemasks.cpu()
+
+    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+    return torch.from_numpy(packed_bits_numpy)
+
+
 def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
     """
     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
-    compressed to R x ceil(C/8)
+    compressed to R x ceil(C/8).
+
+    Supports both CPU and GPU tensors with automatic fallback to NumPy for compatibility.
 
-    :param bytemasks: mask tensor where each byte corresponds to a weight
-    :return: mask tensor where each bit corresounds to a weight
+    :param bytemasks: 2D boolean mask tensor where each element corresponds to a weight
+    :return: Packed mask tensor where each bit corresponds to a weight
+    :raises ValueError: If input tensor is not 2D
     """
+    # Validate input shape
+    _validate_bitmask_shape(bytemasks)
+
     try:
         device = bytemasks.device
         dtype = bytemasks.dtype
 
+        # Ensure boolean type
         if dtype != torch.bool:
             bytemasks = bytemasks.bool()
 
         rows, cols = bytemasks.shape
-        packed_cols = (cols + 7) // 8
-
         bytemasks_uint8 = bytemasks.to(torch.uint8)
-        packed = torch.zeros(rows, packed_cols, dtype=torch.uint8, device=device)
 
-        # Pack bits directly without padding
-        for i in range(cols):
-            packed[:, i // 8] |= bytemasks_uint8[:, i] << (i % 8)
-
-        return packed
+        # Use PyTorch implementation
+        return _pack_bits_torch(bytemasks_uint8, rows, cols, device)
 
     except Exception:
-        if bytemasks.is_cuda:
-            bytemasks = bytemasks.cpu()
-
-        packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-        return torch.from_numpy(packed_bits_numpy)
+        # Fallback to NumPy for compatibility
+        return _pack_bits_numpy_fallback(bytemasks)
 
 
 def unpack_bitmasks(
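
As a sanity check on the two packing paths above, this sketch (not from the commit) verifies that the loop-based PyTorch packing and numpy.packbits with bitorder="little" produce identical bytes; the tensor shape and threshold are arbitrary.

import numpy
import torch

mask = torch.rand(4, 10) > 0.5  # arbitrary 2D boolean bytemask
rows, cols = mask.shape
u8 = mask.to(torch.uint8)

# Same loop as _pack_bits_torch: bit (i % 8) of output byte (i // 8) holds column i
packed = torch.zeros(rows, (cols + 7) // 8, dtype=torch.uint8)
for i in range(cols):
    packed[:, i // 8] |= u8[:, i] << (i % 8)

# Same call as _pack_bits_numpy_fallback
reference = torch.from_numpy(numpy.packbits(mask.numpy(), axis=-1, bitorder="little"))
assert torch.equal(packed, reference)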
