Simple quantization, compatible with vllm/sglang.
```bash
git clone https://github.com/LambdaLabsML/openquant.git
cd openquant
python compress_fp8.py -m Qwen/Qwen3-32B
vllm serve Qwen3-32B-FP8
```
Model/quantization support:

| Model | fp8 | awq |
|---|---|---|
| Qwen3 | ✅ | ✅ |
| Qwen3 MoE | ✅ | |
| Llama 3 | ✅ | ✅ |
| Llama 4 | ✅ | |
| Gemma 3 | ✅ | ✅ |
| Mistral | ✅ | ✅ |
For contributing new model architectures, see examples in openquant/models.py.
```bash
python compress_fp8.py -m Qwen/Qwen3-32B
```
tl;dr:
- model size * 0.5
- throughput * 1.2ish
- (with a lot of caveats)
Models today are usually trained in `bf16`, a floating point format stored in 16 bits (2 bytes). At the billions-of-parameters scale, those bytes add up VERY quickly. The main reason for quantizing a model from `bf16` to `fp8` is memory reduction.
For example, meta-llama/Llama-3.3-70B-Instruct has 70 billion parameters, which at `bf16` (2 bytes each) is 140 billion bytes, or 140 GB of data. A single H100 GPU has 80 GB of GPU RAM, so you'd need at LEAST 2xH100 to serve it, and likely more to leave room for kv cache. If you halve the number of bytes per parameter, the weights only take 70 GB, which fits comfortably on 2xH100 and just barely on 1xH100.
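If you want to sanity-check that arithmetic yourself, here's the back-of-the-envelope version (just illustrative math, not repo code):

```python
# Rough memory math for a 70B-parameter model (weights only; ignores kv cache,
# activations, and framework overhead).
params = 70e9
bf16_gb = params * 2 / 1e9  # 2 bytes per parameter
fp8_gb = params * 1 / 1e9   # 1 byte per parameter
print(bf16_gb, fp8_gb)      # 140.0 70.0
```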
Starting with the NVIDIA H100, GPUs have hardware support for 8-bit floating point numbers (`fp8`), meaning `fp8` performance is >= `bf16` performance (mostly). This performance gain comes from a few places:
- The model takes less GPU RAM => more space for kv cache. Modern inference libraries (like vllm/sglang) get higher and more stable performance with more kv cache.
- Model parameters are half as big => less GPU memory bandwidth used, and more data fits in cache.
- Depending on the GPU, fp8 FLOPS are simply higher than bf16 FLOPS. E.g. see the H100 specifications: bfloat16 is ~2k teraFLOPS and fp8 is ~4k teraFLOPS.
First some facts:
- Model parameters are typically stored using `torch.float8_e4m3fn`
- This format has 1 sign bit, 4 bits for exponent, and 3 bits for mantissa
- Values can be between [-448, +448]
- It's only 8 bits, so there are at most 256 representable values
- No support for infinity
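You can sanity-check these numbers with `torch.finfo` (a quick check, not part of the repo):

```python
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.bits)  # 8
print(info.max)   # 448.0
print(info.min)   # -448.0
print(info.eps)   # 0.125 -- only 3 mantissa bits, so values are coarsely spaced
```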
Here are some sample random numbers at f32/bf16/fp8 (you can see the precision loss as they are stored in fewer bits):
```python
>>> import torch
>>> torch.set_printoptions(precision=6)
>>> q = torch.randn(6); q
tensor([-0.272713, -0.222072, -0.491148, 0.589126, 0.489998, 1.777003])
>>> q.to(dtype=torch.bfloat16)
tensor([-0.273438, -0.221680, -0.490234, 0.589844, 0.490234, 1.773438])
>>> q.to(dtype=torch.float8_e4m3fn)
tensor([-0.281250, -0.218750, -0.500000, 0.562500, 0.500000, 1.750000])
```
And here is how all the representable values are distributed (notice how there are waaaaay more values closer to 0!):
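If you want to reproduce that picture yourself, one way (a standalone sketch, not repo code) is to enumerate every `float8_e4m3fn` bit pattern and histogram the values:

```python
import torch

# Every possible 8-bit pattern, reinterpreted as fp8 (e4m3fn) values.
bits = torch.arange(256, dtype=torch.uint8)
vals = bits.view(torch.float8_e4m3fn).float()
vals = vals[~vals.isnan()]  # e4m3fn has NaN but no +/- inf

# Bucket the representable values across the full [-448, 448] range;
# the vast majority land in the bins around zero.
hist = torch.histogram(vals, bins=8, range=(-448.0, 448.0))
print(hist.hist)
```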
So this leaves us with two questions for quantization:
- `bf16` can store values between [-3.38953e+38, +3.38953e+38]; how do we fit that into the fp8 range of [-448, +448]?
- How do we take advantage of the distribution of values in fp8?
When quantizing a tensor from `bf16` to `fp8`, we don't just convert it to the dtype like I showed above. Instead we do the following:
- Compute the scale from the largest absolute value in the tensor (e.g. scale = absmax / 448, the fp8 max)
- Divide the tensor by the scale (so all values fall inside the fp8 representable range) and cast to fp8
- Store both the quantized tensor & the scale
We need to compute & store this scale so that values larger than what fp8 can store ([-448, +448]) still map into range, and so the original magnitudes can be recovered when dequantizing.
Let's see this in action.
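Here is a minimal sketch of the per-tensor scheme described above; `quantize_fp8` is just a name made up for this example, and the actual code in `compress_fp8.py` may differ in the details:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_fp8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Map the tensor's largest absolute value onto the fp8 max.
    scale = w.abs().max().float() / FP8_MAX
    q = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

w = torch.randn(4, 4, dtype=torch.bfloat16) * 1000       # values well outside [-448, 448]
q, scale = quantize_fp8(w)
w_hat = q.float() * scale                                 # dequantize
print((w.float() - w_hat).abs().max() / w.abs().max())    # small relative error
```

Dequantizing is just `q.float() * scale`, which is conceptually what the inference engine does with the stored `weight_scale` (modulo fused/optimized kernels).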
For compatibility with things like vllm there are a couple of things we need to do:
- Add the `weight_scale` as a parameter to each of the `Linear` layers. This basically means just replacing the `Linear` layer with this `PackedLinear`, where `weight` is the `fp8` tensor and `weight_scale` is the scale.
```python
class PackedLinear(torch.nn.Module):
    """Holds an fp8 weight plus its scale in the layout vllm expects."""

    def __init__(self, weight: torch.Tensor, weight_scale: torch.Tensor):
        super().__init__()
        # The quantized (fp8) weight tensor.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # The scale used to dequantize `weight` back to full precision.
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
```
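Tying the two sketches together, swapping an existing `Linear` for a `PackedLinear` might look roughly like this (the module path and sizes are made up for illustration; the repo's real replacement logic lives in `compress_fp8.py` / `openquant/models.py`):

```python
import torch

# A hypothetical bf16 projection layer from some model.
linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)

# Quantize its weight with the earlier sketch and pack weight + scale together.
q, scale = quantize_fp8(linear.weight.data)
packed = PackedLinear(q, scale)

# Then assign it back into the model, e.g.:
# model.model.layers[i].mlp.up_proj = packed
```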
- Add a `quantization_config` to the model's config. This will also appear in the `config.json` file in the huggingface repo of the model.
```python
model.config.quantization_config = {
    "quant_method": "fp8",
    "is_checkpoint_fp8_serialized": True,
    "activation_scheme": "dynamic",
    "weight_block_size": ...,  # `None` or `[block_size, block_size]`
    "ignored_layers": ...,  # list of module names that are not quantized
}
```
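From there, saving with the usual huggingface API writes out the quantized weights, their scales, and the `quantization_config` entry in `config.json` (a sketch; the directory name matches the quickstart above):

```python
# Hypothetical save step: the fp8 `weight`/`weight_scale` parameters end up in
# the safetensors shards, and `quantization_config` ends up in config.json.
model.save_pretrained("Qwen3-32B-FP8")
```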
And that's all we need to do for vllm!
NOTE: some models don't support all layers being quantized. For example, vllm does not support the `decoder.mlp.gate` linear layer being quantized in Qwen3 MoE models.
MIT License
Copyright (c) 2025 Lambda Labs Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.