Commit 51c7968

Merge branch 'main' into provide_moe_calibration_mode
2 parents db57734 + 70f93d3 commit 51c7968

4 files changed: 78 additions & 4 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def localversion_func(version: ScmVersion) -> str:
         "tqdm>=4.0.0",
         # torch 1.10 and 1.11 do not support quantized onnx export
         "torch>=1.7.0,!=1.10,!=1.11",
-        "transformers>4.0,<4.53.0",
+        "transformers>4.0",
         "datasets",
         "accelerate>=0.20.3,!=1.1.0",
         "pynvml",

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 2 deletions
@@ -307,8 +307,9 @@ def oneshot(
     """
 
     # pass all args directly into Oneshot
-    local_args = locals()
-    local_args.pop("kwargs")
+    local_args = {
+        k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
+    }
     one_shot = Oneshot(**local_args, **kwargs)
     one_shot()
 
src/llmcompressor/entrypoints/utils.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def initialize_processor_from_path(
         )
 
     except ValueError as exception:
-        if "trust_remote_code=True" in exception.value:
+        if any("trust_remote_code=True" in arg for arg in exception.args):
             raise ValueError(
                 f"The repository for {processor_src} contains custom code which must "
                 "be executed to correctly load the tokenizer/processor. You can "
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from torch.nn import Linear, Module, ReLU
+
+from llmcompressor.pytorch.utils import ModuleSparsificationInfo
+
+
+class FakeQuantizedModel(Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = Linear(8, 16, bias=True)  # Quantized
+        self.fc2 = Linear(16, 4, bias=True)  # Unquantized
+        self.relu = ReLU()
+
+        self.fc1.quantization_scheme = QuantizationScheme(
+            targets=["model.fc1"],
+            weights=QuantizationArgs(
+                precision=8,
+                granularity="per_tensor",
+                algorithm="gptq",
+                blocksize=128,
+            ),
+        )
+
+
+def test_module_quantization_info():
+    model = FakeQuantizedModel()
+    state = model.state_dict()
+
+    # Simulate quantized weights: replace float32 weights with int8
+    state["fc1.weight"] = torch.randint(
+        -128, 127, state["fc1.weight"].shape, dtype=torch.int8
+    )
+
+    # Keep fc1.bias, fc2.weight, fc2.bias all as float32
+    info = ModuleSparsificationInfo(model, state_dict=state)
+
+    # fc1 (quantized): 8 * 16 weights + 16 biases = 144 parameters.
+    # fc2 (not quantized): 16 * 4 weights + 4 biases = 68 parameters.
+    # Total parameters: 144 + 68 = 212.
+    # Quantized percentage: (144 / 212) * 100 ≈ 67.92%.
+    percent = info.params_quantized_percent
+
+    assert percent == pytest.approx(67.92, abs=1e-2)
+
+
+class FakeSparsedModel(Module):
+    def __init__(self):
+        super().__init__()
+        self.linear_dense = Linear(10, 10, bias=True)  # no sparsity
+        self.linear_sparse = Linear(10, 10, bias=True)  # sparse layer
+
+        # Inject sparsity into linear_sparse.weight (50% zeros)
+        with torch.no_grad():
+            weight = self.linear_sparse.weight
+            weight.view(-1)[:50] = 0.0
+
+
+def test_module_sparsity_info():
+    model = FakeSparsedModel()
+    state = model.state_dict()
+
+    info = ModuleSparsificationInfo(model, state_dict=state)
+
+    # linear_dense: 10 * 10 weights + 10 biases = 110 parameters.
+    # linear_sparse: 10 * 10 weights + 10 biases = 110 parameters.
+    # Total parameters: 110 + 110 = 220.
+    # Number of sparse (zero) parameters: 50 (from linear_sparse.weight).
+    # Sparsity percentage: (50 / 220) * 100 ≈ 22.73%.
+    percent = info.params_sparse_percent
+
+    assert percent == pytest.approx(22.73, abs=1e-2)
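The expected percentages asserted in the new tests follow from plain parameter counts; a quick sketch reproducing them without llmcompressor:

quantized = 8 * 16 + 16             # fc1 weights + biases = 144
total = quantized + (16 * 4 + 4)    # plus fc2 = 212
print(round(quantized / total * 100, 2))  # 67.92

zeros = 50                          # zeroed entries in linear_sparse.weight
total = 2 * (10 * 10 + 10)          # two Linear(10, 10) layers = 220
print(round(zeros / total * 100, 2))  # 22.73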
