
Commit 3262d85

Make observer args configurable (#1492)
SUMMARY:

There were two ways to pass in arguments:

1. Initialize when calling the observer. See example usage [here](https://github.com/vllm-project/llm-compressor/blob/030a5bee05c7e319350b6cab204a09f47d0ee552/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py#L100):

   ```python
   observer = Observer.load_from_registry(
       quant_args.observer,
       quantization_args=quant_args,
       averaging_constant=1.0,  # ignore moving average
   )
   ```

   The unpacking logic moves to `calibration.py`, which extracts the kwargs directly and then passes them into `load_from_registry`.

2. Defined in the recipe and parsed in the quantization args. These override other sources, except `averaging_constant` when it is being ignored. Example usage:

   ```yaml
   config_groups:
     group_0:
       weights: {num_bits: 8, type: int, symmetric: true, strategy: tensor, observer: mse, observer_kwargs: {patience: 5}}
       input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
       targets: [Linear]
   ```

TEST PLAN:

Tested locally.

---------

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent c052d2c commit 3262d85

File tree

4 files changed: +114 −4 lines changed


docs/observers.md

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Observers Overview

An `Observer` in `llm-compressor` is a utility class responsible for analyzing tensors (e.g., weights, activations) and producing quantization parameters such as `scale` and `zero_point`. These observers are used by quantization modifiers to compute the statistics necessary for transforming tensors into lower-precision formats.

Observers are designed to be flexible and support a variety of quantization strategies, including per-tensor, per-group, per-channel, and per-token quantization.

## Base Class

### [Observer](../src/llmcompressor/observers/base.py)
Base class for all observers. Subclasses must implement the `calculate_qparams` method to define how quantization parameters are computed.

The base class handles:
- Group-wise scale/zero_point computation
- Token-wise and channel-wise quantization logic
- Optional support for `g_idx` (group index mappings)
- Recording observed tokens for logging and analysis
- Resetting internal state during lifecycle transitions

This class is not used directly but provides the scaffolding for all custom observers.
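To make this concrete, below is a minimal sketch of what a custom observer could look like. It assumes the registry exposes a `register` decorator, that the base class stores `quantization_args` on the instance, and that `calculate_qparams` receives the observed tensor and returns a `(scale, zero_point)` pair; check `base.py` for the exact signature, as this is illustrative rather than a drop-in implementation.

```python
import torch

from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer


@Observer.register("absmax")  # hypothetical name; decorator assumed from the registry pattern
class AbsMaxObserver(Observer):
    """Toy observer: symmetric scale derived from the absolute maximum of the tensor."""

    def __init__(self, quantization_args: QuantizationArgs, **kwargs):
        super().__init__(quantization_args=quantization_args)

    def calculate_qparams(self, observed: torch.Tensor, *args, **kwargs):
        # Assumes the base class exposes self.quantization_args.
        q_max = 2 ** (self.quantization_args.num_bits - 1) - 1
        scale = observed.abs().amax().clamp(min=1e-8) / q_max
        zero_point = torch.zeros_like(scale, dtype=torch.int8)  # symmetric: zero point stays 0
        return scale, zero_point
```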
## Implemented Observers

### [MinMax](../src/llmcompressor/observers/min_max.py)
Computes `scale` and `zero_point` by tracking the minimum and maximum of the observed tensor. This is the simplest and most common observer, and it works well for both symmetric and asymmetric quantization.

Best used for:
- Int8 or Int4 symmetric quantization
- Channel-wise or group-wise strategies
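The arithmetic behind this observer is straightforward; the standalone snippet below (not the library code) shows how an asymmetric int8 scale and zero point fall out of a tensor's min/max values.

```python
import torch

x = torch.randn(64, 512)
min_val, max_val = x.amin(), x.amax()

# Map [min_val, max_val] onto the signed int8 range [-128, 127].
q_min, q_max = -128, 127
scale = (max_val - min_val) / (q_max - q_min)
zero_point = (q_min - min_val / scale).round().clamp(q_min, q_max).to(torch.int8)
```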
### [MSE](../src/llmcompressor/observers/mse.py)
Computes quantization parameters by minimizing the Mean Squared Error (MSE) between the original and quantized tensor. Optionally maintains a moving average of min/max values for smoother convergence.

Best used when:
- Calibration accuracy is critical
- Quantization error needs to be tightly controlled
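The standalone sketch below mirrors the idea behind the MSE search and shows how `maxshrink`, `grid`, `norm`, and `patience` interact; it is a simplified illustration, not the observer's actual implementation.

```python
import torch


def mse_clip_search(x, num_bits=8, maxshrink=0.2, grid=100.0, norm=2.4, patience=5):
    """Shrink the clipping range step by step, keeping the scale with the lowest error."""
    q_max = 2 ** (num_bits - 1) - 1
    abs_max = x.abs().amax()
    best_err, best_scale, misses = float("inf"), abs_max / q_max, 0

    for i in range(int(maxshrink * grid)):
        shrink = 1.0 - i / grid                                   # candidate clipping factor
        scale = (shrink * abs_max) / q_max
        q = (x / scale).round().clamp(-q_max - 1, q_max) * scale  # fake-quantize
        err = (q - x).abs().pow(norm).mean()                      # `norm` controls the error metric
        if err < best_err:
            best_err, best_scale, misses = err, scale, 0
        else:
            misses += 1
            if misses >= patience:                                # stop after non-improving steps
                break
    return best_scale


scale = mse_clip_search(torch.randn(64, 512))
```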
## Quantization Strategies

Observers support multiple quantization strategies via the `QuantizationArgs.strategy` field:

- `TENSOR`: Global scale and zero_point across the entire tensor.
- `GROUP`, `TENSOR_GROUP`: Slice the tensor into equal-sized groups along columns.
- `CHANNEL`: Per-channel quantization (e.g., across output dimensions).
- `TOKEN`: Quantize activations along token or sequence dimensions.
- `BLOCK`: *(Not yet implemented)* Placeholder for block-wise quantization.
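As a rough illustration of the group strategies (not the library's internal code), a weight matrix can be viewed as equal-sized column slices, each carrying its own scale:

```python
import torch

weight = torch.randn(64, 512)
group_size = 128

# View the columns as groups of `group_size` and compute one symmetric scale per group.
grouped = weight.reshape(weight.shape[0], -1, group_size)  # [rows, num_groups, group_size]
scales = grouped.abs().amax(dim=-1) / 127                  # one int8 scale per (row, group)
```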
## Observer Configuration Parameters

Observers can be configured with optional keyword arguments that control their behavior. These are passed through the `QuantizationArgs.observer_kwargs` dictionary and parsed internally when the observer is initialized.

Below are the supported configuration parameters and their default values:

| Argument             | Default Value |
|----------------------|---------------|
| `maxshrink`          | `0.20`        |
| `patience`           | `5`           |
| `averaging_constant` | `0.01`        |
| `grid`               | `100.0`       |
| `norm`               | `2.4`         |
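These kwargs can also be set programmatically. The snippet below assumes `QuantizationArgs` exposes `observer` and `observer_kwargs` fields (as they are used in recipes) and simply forwards them to the registry, mirroring what the calibration flow does, with defaults applied for anything omitted.

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer

args = QuantizationArgs(
    num_bits=8,
    strategy="tensor",
    observer="mse",
    observer_kwargs={"maxshrink": 0.1, "patience": 10},
)

# Unpack the configured kwargs and hand them to the registry.
observer = Observer.load_from_registry(
    args.observer,
    quantization_args=args,
    **(args.observer_kwargs or {}),
)
```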
## Example Usage

```python
import torch

from llmcompressor.observers import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
observer = Observer.load_from_registry("minmax", quantization_args=args)

x = torch.randn(64, 512)
scale, zero_point = observer(x)
```
## Example YAML Usage

```yaml
quantization_stage:
  quantization_modifiers:
    GPTQModifier:
      weights:
        observer: mse
        observer_kwargs:
          maxshrink: 0.1
          patience: 10
          averaging_constant: 0.05
          grid: 128.0
          norm: 2.0
        num_bits: 4
        type: int
        symmetric: true
        strategy: channel
      targets:
        - Linear
```
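For context, a recipe like the one above is typically consumed through the `oneshot` calibration entrypoint. The model id, dataset name, and recipe path below are placeholders for illustration; adapt them to your setup.

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot

# Placeholder model/recipe/dataset, shown only to illustrate how a recipe that
# carries observer_kwargs is picked up during calibration.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
oneshot(
    model=model,
    recipe="recipe.yaml",
    dataset="open_platypus",
    num_calibration_samples=64,
    max_seq_length=512,
)
```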

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 15 additions & 0 deletions
@@ -18,6 +18,12 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
+DEFAULT_MAXSHRINK = 0.20
+DEFAULT_PATIENCE = 5
+DEFAULT_AVERAGING_CONSTANT = 0.01
+DEFAULT_GRID = 100.0
+DEFAULT_NORM = 2.4
+
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -60,9 +66,18 @@ def initialize_observer(
         False,
         DynamicType.LOCAL,
     ):
+        observer_kwargs = quantization_args.observer_kwargs or {}
         observer = Observer.load_from_registry(
             quantization_args.observer,
             quantization_args=quantization_args,
+            averaging_constant=observer_kwargs.get(
+                "averaging_constant", DEFAULT_AVERAGING_CONSTANT
+            ),
+            # used by mse observer only, will be ignored by minmax observer
+            maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
+            patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
+            grid=observer_kwargs.get("grid", DEFAULT_GRID),
+            norm=observer_kwargs.get("norm", DEFAULT_NORM)
         )
         module.register_module(f"{base_name}_observer", observer)

src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def __init__(
         self,
         quantization_args: QuantizationArgs,
         averaging_constant: float = 0.01,
+        **kwargs,
     ):
         super().__init__(quantization_args=quantization_args)
src/llmcompressor/observers/mse.py

Lines changed: 5 additions & 4 deletions
@@ -20,18 +20,19 @@ class MovingAverageMSEObserver(Observer):
     def __init__(
         self,
         quantization_args: QuantizationArgs,
+        maxshrink: float = 0.2,
+        patience: int = 5,
         averaging_constant: float = 0.01,
         grid: float = 100.0,
         norm: float = 2.4,
+        **kwargs,
     ):
         super().__init__(quantization_args=quantization_args)
 
-        kwargs = quantization_args.observer_kwargs or {}
-        self.maxshrink = kwargs.get("maxshrink", 0.20)
-        self.patience = kwargs.get("patience", 5)
-
         self.min_val = {}
         self.max_val = {}
+        self.maxshrink = maxshrink
+        self.patience = patience
         self.averaging_constant = averaging_constant
         self.grid = grid
         self.norm = norm
