
Commit 7823af0

dsikka and rahul-tuli authored

Fix case when using weight_packed, not weight (#278)

* Fix CompressedLinear bug
* Wrap decompressed weight in a Parameter
* Add tests to verify CompressedLinear usage and its forward-pass invocation

Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent 95ba68d commit 7823af0

File tree

3 files changed: +115 -2 lines changed

src/compressed_tensors/linear/compressed_linear.py

Lines changed: 4 additions & 2 deletions
@@ -99,8 +99,10 @@ def forward(self, input: Tensor) -> Tensor:
         Decompresses the weight, then runs the wrapped forward pass
         """
         if self.quantization_status == QuantizationStatus.COMPRESSED:
-            decompressed_weight = self.compressor.decompress_module(self)
-            self.weight.data = decompressed_weight
+            weight_data = self.compressor.decompress_module(self)
+            param = Parameter(weight_data, requires_grad=False)
+            register_offload_parameter(self, "weight", param)
+
             self.quantization_status = QuantizationStatus.FROZEN

         return linear(input, self.weight, self.bias)
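
Why the wrapping matters: the old code mutated self.weight.data in place, which assumes the module already exposes a weight attribute. For packed checkpoints the compressed module carries weight_packed instead, so there is nothing to assign into. Registering the decompressed tensor as a fresh Parameter creates the weight attribute regardless of how the compressed state was stored. Below is a minimal, self-contained sketch of that pattern; PackedLinear and decompress are hypothetical stand-ins for CompressedLinear and compressor.decompress_module, and plain register_parameter is used in place of the library's register_offload_parameter.

import torch
from torch.nn import Module, Parameter


class PackedLinear(Module):
    """Hypothetical stand-in: stores weight_packed, never a weight attribute."""

    def __init__(self, packed: torch.Tensor):
        super().__init__()
        self.weight_packed = Parameter(packed, requires_grad=False)

    def decompress(self) -> torch.Tensor:
        # Placeholder for compressor.decompress_module(self)
        return self.weight_packed.to(torch.float16)


module = PackedLinear(torch.randint(0, 16, (8, 4), dtype=torch.uint8))

# Old approach: module.weight does not exist yet, so this raises AttributeError.
# module.weight.data = module.decompress()

# New approach: wrap the decompressed tensor in a Parameter and register it,
# so module.weight exists for the subsequent linear(input, self.weight, self.bias) call.
module.register_parameter("weight", Parameter(module.decompress(), requires_grad=False))
assert module.weight.shape == (8, 4)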

tests/test_linear/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.linear.compressed_linear import CompressedLinear
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def models_with_linear_quantized():
+    return [
+        # weights packed
+        "nm-testing/llama2.c-stories110M-gsm8k-recipe_w4a16_actorder_weight-compressed",
+        # weights not packed
+        "nm-testing/llama2.c-stories110M-gsm8k-fp8_dynamic-compressed",
+    ]
+
+
+@pytest.mark.parametrize("model_stub", models_with_linear_quantized())
+def test_model_forward_pass(model_stub):
+    """
+    Test that AutoModelForCausalLM can process tokenized inputs and generate output.
+    """
+    # Load model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_stub, torch_dtype=torch.float16, device_map="auto"
+    )
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_stub)
+
+    # Define sample inputs
+    sample_inputs = [
+        "I love quantization because",
+        "What is the capital of France?",
+        "def fibonacci(n):",
+    ]
+
+    # Move inputs to the correct device
+    device = next(model.parameters()).device
+    inputs = tokenizer(sample_inputs, return_tensors="pt", padding=True).to(device)
+
+    # Run model inference (forward pass)
+    outputs = model.generate(**inputs, max_length=50)
+
+    # Ensure output is not empty
+    assert outputs is not None, "Model forward pass failed, no output generated."
+
+
+@pytest.mark.parametrize("model_stub", models_with_linear_quantized())
+def test_compressed_linear_from_linear_usage(monkeypatch, model_stub):
+    """
+    Test that CompressedLinear.from_linear is used for creating
+    CompressedLinear instances.
+    """
+    call_count = 0
+
+    original_from_linear = CompressedLinear.from_linear
+
+    def fake_from_linear(*args, **kwargs):
+        nonlocal call_count
+        call_count += 1
+        return original_from_linear(*args, **kwargs)
+
+    # Replace the original from_linear with our fake to count its invocations
+    monkeypatch.setattr(CompressedLinear, "from_linear", fake_from_linear)
+
+    # Load model to trigger the creation of CompressedLinear instances
+    model = AutoModelForCausalLM.from_pretrained(
+        model_stub, torch_dtype="auto", device_map="auto"
+    )
+
+    # Known quantized layers that should be
+    # instances of CompressedLinear
+    # (This is not an exhaustive list)
+    quantized_layers = {"q_proj", "k_proj", "v_proj"}
+
+    # Check that the expected layers are instances of CompressedLinear
+    for layer_name, module in model.named_modules():
+        if any(layer in layer_name for layer in quantized_layers):
+            assert isinstance(
+                module, CompressedLinear
+            ), f"{layer_name} should be an instance of CompressedLinear"
+            f"but got {type(module).__name__}"
+
+    assert call_count > 0, "`CompressedLinear.from_linear` was not used during the "
+    "creation of CompressedLinear instances."
