Skip to content

Commit 70f93d3

Browse files
authored
[Tests] Expand tests for ModuleSparsificationInfo (#1631)
## Purpose ## Add unit tests to verify class `ModuleSparsificationInfo` return value ## Changes ## Add `test_sparse.py` to test `params_sparse_percent()` and `params_quantized_percent()`, providing indirect coverage for additional methods in class `ModuleSparsificationInfo`. For Issue #1224
1 parent 75d753d commit 70f93d3

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
import torch
3+
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
4+
from torch.nn import Linear, Module, ReLU
5+
6+
from llmcompressor.pytorch.utils import ModuleSparsificationInfo
7+
8+
9+
class FakeQuantizedModel(Module):
10+
def __init__(self):
11+
super().__init__()
12+
self.fc1 = Linear(8, 16, bias=True) # Quantized
13+
self.fc2 = Linear(16, 4, bias=True) # Unquantized
14+
self.relu = ReLU()
15+
16+
self.fc1.quantization_scheme = QuantizationScheme(
17+
targets=["model.fc1"],
18+
weights=QuantizationArgs(
19+
precision=8,
20+
granularity="per_tensor",
21+
algorithm="gptq",
22+
blocksize=128,
23+
),
24+
)
25+
26+
27+
def test_module_quantization_info():
28+
model = FakeQuantizedModel()
29+
state = model.state_dict()
30+
31+
# Simulate quantized weights: replace float32 weights with int8
32+
state["fc1.weight"] = torch.randint(
33+
-128, 127, state["fc1.weight"].shape, dtype=torch.int8
34+
)
35+
36+
# Keep fc1.bias, fc2.weight, fc2.bias all as float32
37+
info = ModuleSparsificationInfo(model, state_dict=state)
38+
39+
# fc1 (quantized): 8 * 16 weights + 16 biases = 144 parameters.
40+
# fc2 (not quantized): 16 * 4 weights + 4 biases = 68 parameters.
41+
# Total parameters: 144 + 68 = 212.
42+
# Quantized percentage: (144 / 212) * 100 ≈ 67.92%.
43+
percent = info.params_quantized_percent
44+
45+
assert percent == pytest.approx(67.92, abs=1e-2)
46+
47+
48+
class FakeSparsedModel(Module):
49+
def __init__(self):
50+
super().__init__()
51+
self.linear_dense = Linear(10, 10, bias=True) # no sparsity
52+
self.linear_sparse = Linear(10, 10, bias=True) # sparse layer
53+
54+
# Inject sparsity into linear_sparse.weight (50% zeros)
55+
with torch.no_grad():
56+
weight = self.linear_sparse.weight
57+
weight.view(-1)[:50] = 0.0
58+
59+
60+
def test_module_sparsity_info():
61+
model = FakeSparsedModel()
62+
state = model.state_dict()
63+
64+
info = ModuleSparsificationInfo(model, state_dict=state)
65+
66+
# linear_dense: 10 * 10 weights + 10 biases = 110 parameters.
67+
# linear_sparse: 10 * 10 weights + 10 biases = 110 parameters.
68+
# Total parameters: 110 + 110 = 220
69+
# Number of sparse (zero) parameters: 50 (from linear_sparse.weight).
70+
# Sparsity percentage: (50 / 220) * 100 ≈ 22.73%.
71+
percent = info.params_sparse_percent
72+
73+
assert percent == pytest.approx(22.73, abs=1e-2)

0 commit comments

Comments
 (0)