Skip to content

Support DeepSeekV3-style block FP8 quantization #1607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/schemes.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la
- Useful for speed ups in high QPS regimes or offline serving on vLLM.
- Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell).

### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py)
- Uses block-wise quantization to compress weights to FP8 in (commonly 128×128 tiles), and dynamic per-token-group (128) quantization for activations. Does not require calibration dataset. Activation quantization is carried out during inference on vLLM.

## Sparsification
Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:

Expand Down
35 changes: 35 additions & 0 deletions examples/quantization_w8a8_fp8/fp8_block_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-0.6B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per channel via ptq
# * quantize the activations to fp8 with dynamic per token
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
21 changes: 19 additions & 2 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,25 @@ def call_observer(
updated_scale, updated_zero_point = observer(
value, g_idx=g_idx, global_scale=global_scale
)
update_parameter_data(module, updated_scale, f"{base_name}_scale")
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
# register or update scale & zero_point parameters (supports block shapes)
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape:
if hasattr(module, scale_name):
delattr(module, scale_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kylesayrs do we need this and L140 to be delete_offload_parameter?

module.register_parameter(
scale_name, torch.nn.Parameter(updated_scale.clone())
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kylesayrs do we need these (this and L141-143) to use register_offload_parameter instead?

else:
update_parameter_data(module, updated_scale, scale_name)
if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape:
if hasattr(module, zp_name):
delattr(module, zp_name)
module.register_parameter(
zp_name, torch.nn.Parameter(updated_zero_point.clone())
)
else:
update_parameter_data(module, updated_zero_point, zp_name)


def update_weight_global_scale(module: Module):
Expand Down
30 changes: 24 additions & 6 deletions src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,30 @@ def get_qparams(
)

elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
# TODO (#1475) add support for block-wise quantization
raise NotImplementedError(
"Block-wise quantization is not yet supported, "
"consider group-wise quantization instead. More info at "
"https://github.com/vllm-project/llm-compressor/issues/1475"
)
# Block-wise quantization: one scale/zero_point per block of shape [block_rows, block_cols]
rows, cols = observed.shape[:2]
bs = self.quantization_args.block_structure
if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)):
raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].")
block_rows, block_cols = bs
num_br = int(ceil(rows / block_rows))
num_bc = int(ceil(cols / block_cols))
# allocate per-block scale and zero_point
self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
# compute qparams for each block
for i in range(num_br):
r0 = i * block_rows
r1 = min((i + 1) * block_rows, rows)
for j in range(num_bc):
c0 = j * block_cols
c1 = min((j + 1) * block_cols, cols)
# reduce across both dims to get one scale and zp per block
scale_bp, zp_bp = self.calculate_qparams(
observed[r0:r1, c0:c1], reduce_dims=(0, 1)
)
self._scale[i, j] = scale_bp
self._zero_point[i, j] = zp_bp

return self._scale, self._zero_point

Expand Down
26 changes: 26 additions & 0 deletions tests/llmcompressor/modifiers/quantization/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,32 @@ def q_config_kwargs(config_0, config_1):
)
)

@pytest.fixture
def block_q_config_kwargs():
return dict(
config_groups=dict(
group_block=dict(
targets=["Linear"],
input_activations=dict(
num_bits=8, symmetric=True, strategy="group", group_size=128
),
weights=dict(
num_bits=8,
symmetric=True,
strategy="block",
block_structure=[128, 128],
),
),
)
)

def test_block_strategy_parsing(block_q_config_kwargs):
modifier = GPTQModifier(**block_q_config_kwargs)
resolved = modifier.resolve_quantization_config()
w_scheme = resolved.config_groups["group_block"].weights
assert w_scheme.strategy == "block"
assert w_scheme.block_structure == [128, 128]


@pytest.mark.parametrize(
"has_actorder,actorder,config_0,config_1,expected_0,expected_1",
Expand Down
Loading