
Unexpected Input Scale Shape for Dynamic Per-Token Activation Quantization #394

@max410011

Description


Problem

I'm using the llm-compressor W8A8 quantization example, and I noticed that the dynamically computed activation input_scale has an unexpected shape when per-token quantization is applied.

My input tensor shape is:
(1, 5, 2048) # (bs, seqlen, hidden_dim)

Since I'm using per-token quantization, I expect the input_scale shape to be:
(1, 5, 1) # One scale per token

But what I actually got is:
(1, 5, 2048) # One scale per token per hidden dim (not expected)

After tracing the code, I found the following in src/compressed_tensors/quantization/utils/helpers.py#L169-L199:

if args.strategy == QuantizationStrategy.TOKEN:
    # `dim` lists the dims to *keep*; everything else is reduced below.
    # For a 3-D activation (bs, seqlen, hidden_dim), dim = {1, 2} makes
    # reduce_dims == (0,), so only the batch dim is reduced and hidden_dim is kept.
    dim = {1, 2}
    reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)

...

if not reduce_dims:
    min_val, max_val = torch.aminmax(value)
else:
    # With keep_dims, min/max (and therefore the scale) retain the non-reduced dims,
    # which in the reported case yields shape (1, 5, 2048) instead of (1, 5, 1).
    min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
    max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)

This reduces the batch dimension (dim=0) but keeps hidden_dim (dim=2), which is the opposite of what we want.
For per-token quantization, only hidden_dim (dim=2) should be reduced, while the batch and sequence-length dimensions are kept.
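
For reference, here is a minimal standalone sketch (plain PyTorch, not the library code) showing how the choice of kept dims determines the resulting scale shape:

```python
import torch

value = torch.randn(1, 5, 2048)  # (bs, seqlen, hidden_dim)

def minmax_shape(keep_dims_set):
    # Mirrors the helper's logic: dims in the set are kept, the rest are reduced.
    reduce_dims = tuple(i for i in range(value.ndim) if i not in keep_dims_set)
    max_val = torch.amax(value, dim=reduce_dims, keepdim=True)
    return max_val.shape  # the scale is derived from min/max, so it has this shape

print(minmax_shape({1, 2}))  # current code: torch.Size([1, 5, 2048])
print(minmax_shape({0, 1}))  # proposed fix: torch.Size([1, 5, 1])
```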

Fix

I have opened PR #393 to fix the bug:

```diff
 if args.strategy == QuantizationStrategy.TOKEN:
-    dim = {1, 2}
+    dim = {0, 1}
     reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
```

This ensures the input_scale shape becomes (bs, seqlen, 1), which aligns with true per-token quantization.
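
For intuition, a (bs, seqlen, 1) scale broadcasts over the hidden dimension exactly as per-token quantization requires. A minimal symmetric int8 illustration (a sketch for intuition only, not the library's quantize path):

```python
import torch

value = torch.randn(1, 5, 2048)                          # (bs, seqlen, hidden_dim)
scale = value.abs().amax(dim=-1, keepdim=True) / 127.0   # (1, 5, 1): one scale per token
scale = scale.clamp(min=1e-8)                            # guard against all-zero tokens
q = torch.clamp(torch.round(value / scale), -128, 127)   # broadcasts over hidden_dim
deq = q * scale                                          # per-token dequantization
```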

Reproduce the Bug

The example I ran is from this link. You can also use the code below.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

# Step 1: Select model and load it.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Step 2: Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)

def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }

ds = ds.map(preprocess)

# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=ds.column_names)

# Step 3: Apply Quantization

# Configure algorithms. In this case, we:
#   * quantize the weights to int8 (static per channel)
#   * quantize the activations to int8 (dynamic per token)
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms and save to output_dir
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

You can print the input and scale shapes in src/compressed_tensors/quantization/lifecycle/forward.py#L373-L377:

if args.dynamic in (True, DynamicType.LOCAL):
    # dynamic quantization - determine the scale/zp on the fly
    scale, zero_point = compute_dynamic_scales_and_zp(
        value=value, args=args, module=module, global_scale=global_scale
    )
    print(f"Input shape: {value.shape}")
    print(f"Scale shape: {scale.shape}")
