Problem
I'm using the llm-compressor example with W8A8 quantization settings, and I noticed that the dynamically computed activation input_scale shape seems incorrect when applying per-token quantization.
My input tensor has shape `(1, 5, 2048)`, i.e. `(bs, seqlen, hidden_dim)`.
Since I'm using per-token quantization, I expect the `input_scale` shape to be `(1, 5, 1)`: one scale per token.
But what I actually get is `(1, 5, 2048)`: one scale per token per hidden dimension, which is not expected.
After tracing the code, I found the cause in `src/compressed_tensors/quantization/utils/helpers.py#L169-L199`:
```python
if args.strategy == QuantizationStrategy.TOKEN:
    dim = {1, 2}
    reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
...
if not reduce_dims:
    min_val, max_val = torch.aminmax(value)
else:
    min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
    max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
```
With `dim = {1, 2}` and a 3D input, `reduce_dims` evaluates to `(0,)`, so the min/max (and therefore the scale) are reduced over the batch dimension (dim=0), while both seqlen (dim=1) and hidden_dim (dim=2) are kept. This is the opposite of what we want.
Per-token quantization should reduce only over hidden_dim (dim=2) and keep the batch and sequence-length dimensions.
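To make the shape arithmetic concrete, here is a small pure-Python sketch that reuses the `reduce_dims` expression quoted above and computes the resulting scale shape, assuming `keep_dims=True` (each reduced axis collapses to size 1). The helper names `reduce_dims_for` and `scale_shape` are hypothetical and are not part of compressed-tensors.

```python
from typing import Set, Tuple


def reduce_dims_for(ndim: int, kept: Set[int]) -> Tuple[int, ...]:
    # Same expression as in helpers.py: reduce every axis that is NOT in `kept`.
    return tuple(idx for idx in range(ndim) if idx not in kept)


def scale_shape(shape: Tuple[int, ...], kept: Set[int]) -> Tuple[int, ...]:
    # With keepdim semantics, every reduced axis becomes size 1; other axes are unchanged.
    reduce = reduce_dims_for(len(shape), kept)
    return tuple(1 if i in reduce else size for i, size in enumerate(shape))


if __name__ == "__main__":
    for bs in (1, 2):
        shape = (bs, 5, 2048)  # (bs, seqlen, hidden_dim)
        print(shape, "current {1, 2}:", scale_shape(shape, {1, 2}),
              "proposed {0, 1}:", scale_shape(shape, {0, 1}))
```

With `bs=2` the problem is easier to see: the current setting collapses the batch axis and yields `(1, 5, 2048)`, while keeping `{0, 1}` yields `(2, 5, 1)`.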
Fix
I have created PR #393 to fix the bug:
```diff
 if args.strategy == QuantizationStrategy.TOKEN:
-    dim = {1, 2}
+    dim = {0, 1}
     reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
```
For a 3D input of shape `(bs, seqlen, hidden_dim)`, this makes `reduce_dims` equal to `(2,)`, so the `input_scale` shape becomes `(bs, seqlen, 1)`, which matches true per-token quantization.
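For reference, this is roughly how a `(bs, seqlen, 1)` scale is meant to be used: a minimal sketch of symmetric per-token int8 quantization, assuming each token's scale is its absolute maximum divided by 127. That is a common convention, not necessarily the exact formula compressed-tensors uses (zero points and asymmetric schemes are omitted), and all names here are hypothetical. Plain nested lists stand in for tensors so the sketch has no dependencies.

```python
from typing import List

Tensor3D = List[List[List[float]]]  # nested lists shaped (bs, seqlen, hidden_dim)

Q_MIN, Q_MAX = -128, 127
EPS = 1e-8  # guards against division by zero for an all-zero token


def per_token_scales(x: Tensor3D) -> Tensor3D:
    """One scale per token: reduce over hidden_dim only -> shape (bs, seqlen, 1)."""
    return [[[max(max(abs(v) for v in token), EPS) / Q_MAX] for token in seq] for seq in x]


def quantize(x: Tensor3D, scales: Tensor3D) -> List[List[List[int]]]:
    """Divide each token by its own scale (broadcast over hidden_dim), round, then clamp."""
    return [
        [
            [min(max(round(v / s[0]), Q_MIN), Q_MAX) for v in token]
            for token, s in zip(seq, seq_scales)
        ]
        for seq, seq_scales in zip(x, scales)
    ]


def dequantize(q: List[List[List[int]]], scales: Tensor3D) -> Tensor3D:
    """Multiply each quantized token back by its own scale."""
    return [
        [[v * s[0] for v in token] for token, s in zip(seq, seq_scales)]
        for seq, seq_scales in zip(q, scales)
    ]


if __name__ == "__main__":
    # bs=1, seqlen=3, hidden_dim=4
    x = [[[0.5, -1.0, 2.0, 0.25], [10.0, -3.0, 0.0, 7.5], [0.0, 0.0, 0.0, 0.0]]]
    scales = per_token_scales(x)
    print(len(scales), len(scales[0]), len(scales[0][0]))  # 1 3 1
```

Because each token gets its own scale, a token with large outliers (the second one) does not reduce the precision of the smaller-valued tokens, which is the point of per-token over per-batch reduction.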
Reproduce the Bug
The example I ran is from this link. You can also use the code below.
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

# Step 1: Select model and load it.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Step 2: Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Step 3: Apply quantization.
# Configure algorithms. In this case, we:
#   * quantize the weights to int8 (static per channel)
#   * quantize the activations to int8 (dynamic per token)
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
```
To observe the shapes, print the input and scale shapes right after the dynamic scale computation in `src/compressed_tensors/quantization/lifecycle/forward.py#L373-L377`:
```python
if args.dynamic in (True, DynamicType.LOCAL):
    # dynamic quantization - determine the scale/zp on the fly
    scale, zero_point = compute_dynamic_scales_and_zp(
        value=value, args=args, module=module, global_scale=global_scale
    )
    # Debug prints to inspect the input and scale shapes
    print(f"Input shape: {value.shape}")
    print(f"Scale shape: {scale.shape}")
```