
Commit 332d4cb

[Feature][Quantization] MXFP4 support for MOE models (#17888)
Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
1 parent bf03ff3 commit 332d4cb

File tree

15 files changed, +875 -106 lines changed


docs/features/quantization/quark.md

Lines changed: 25 additions & 0 deletions
````diff
@@ -229,3 +229,28 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
                           --model_export hf_format \
                           --tasks gsm8k
 ```
+
+## Using MXFP4 models
+
+vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with the [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
+
+The scheme currently only supports dynamic quantization for activations.
+
+Example usage, after installing the latest AMD Quark release:
+
+```bash
+vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
+```
+
+On devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), the MXFP4 matrix multiplication can be simulated by dequantizing the weights from MXFP4 to half precision on the fly with a fused kernel. This is useful, for example, to evaluate MXFP4 models with vLLM, or to benefit from the ~4x memory savings compared to float16 and bfloat16.
+
+To generate models quantized offline with the MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), for example:
+
+```bash
+python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
+                         --quant_scheme w_mxfp4_a_mxfp4_sym \
+                         --output_dir qwen_1.5-moe-a2.7b-mxfp4 \
+                         --skip_evaluation \
+                         --model_export hf_format \
+                         --group_size 32
+```
````
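For context, MXFP4 stores 4-bit (e2m1) elements with one shared 8-bit (e8m0) scale per block of 32 values, i.e. 4 + 8/32 = 4.25 bits per element versus 16, which is where the ~4x saving comes from. The served checkpoint above can also be evaluated offline through vLLM's Python API; the snippet below is a minimal sketch, not part of the committed docs, and simply reuses the `fxmarty/qwen_1.5-moe-a2.7b-mxfp4` model id from the example:

```python
# Minimal offline-inference sketch (not from the committed docs), reusing the
# MXFP4 checkpoint referenced in the documentation above.
from vllm import LLM, SamplingParams

llm = LLM(model="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["Today I am in the French Alps and"], sampling_params)
print(outputs[0].outputs[0].text)
```

On GPUs without native MXFP4 support, this path exercises the on-the-fly dequantization described above.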

tests/kernels/moe/test_moe.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -174,6 +174,7 @@ def test_fused_moe(
                     use_int8_w8a8=False,
                     use_int8_w8a16=False,
                     use_int4_w4a16=False,
+                    use_mxfp4_w4a4=False,
                     per_act_token_quant=False,
                     block_shape=None)
```

tests/kernels/moe/test_mxfp4_moe.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import importlib.metadata
from dataclasses import dataclass

import pytest
import torch
from packaging import version

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
    "quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')


@dataclass
class ModelCase:
    model_id: str
    tp: int


@pytest.mark.parametrize('model_case', [
    ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
    ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
    ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
                    reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
                    f"{torch.cuda.device_count()}")

    with vllm_runner(model_case.model_id,
                     tensor_parallel_size=model_case.tp,
                     load_format="dummy") as llm:

        # TODO: llm.apply_model(check_model) currently relies on V0 internals.
        # Re-enable this later.
        # def check_model(model):
        #     layer = model.model.layers[0]

        #     qkv_proj = layer.self_attn.qkv_proj

        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        #     assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)

        #     assert isinstance(layer.mlp.experts.quant_method,
        #                       QuarkW4A4MXFp4MoEMethod)

        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
        #     llm.apply_model(check_model)

        output = llm.generate_greedy("Today I am in the French Alps and",
                                     max_tokens=20)
        assert output
```

tests/quantization/reference_mxfp4.py

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

BFLOAT16_EXP_BIAS = 127
BFLOAT16_MANTISSA_BITS = 7
BFLOAT16_EXP_BITS = 8

FLOAT16_EXP_BIAS = 15
FLOAT16_MANTISSA_BITS = 10
FLOAT16_EXP_BITS = 5

FLOAT8_E8M0_MAX_EXP = 127
FLOAT4_EXP_BIAS = 1
FLOAT4_MANTISSA_BITS = 1

FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
FLOAT16_SIGN_EXPONENT_MASK = ((
    (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS)

BFLOAT16_VAL_TO_ADD = (1 <<
                       (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
BFLOAT16_SIGN_EXPONENT_MASK = ((
    (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS)


def e8m0_to_half(scale, half_dtype: torch.dtype):
    assert scale.dtype == torch.uint8

    scale_exp = scale.to(torch.int16) - 127

    # This can be implemented with bitwise operations in a proper kernel.
    scale_half = 2.0**(scale_exp.to(torch.float))

    return scale_half.to(half_dtype)


def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
                               half_exp_bias: int, half_mantissa_bits: int):
    assert val.dtype == torch.uint8

    unpacked = torch.zeros(*val.shape[:-1],
                           val.shape[-1] * 2,
                           dtype=torch.uint8,
                           device=val.device)
    unpacked[..., 1::2] = (val >> 4) & 0x0F  # Extract high 4 bits.
    unpacked[..., ::2] = val & 0x0F  # Extract low 4 bits.

    # Takes one float4 values represented as b0000xxxx,
    # and converts it to the corresponding float16 value.

    sign = unpacked >> 3

    exp = (unpacked >> 1) & 3
    new_mantissa = unpacked & 1

    # if exp == 0 and new_mantissa == 0:
    #     new_exp = 0
    # else:
    #     new_exp = exp - FLOAT4_EXP_BIAS + FLOAT16_EXP_BIAS

    # int8_t works with float16, but may overflow with bfloat16.
    new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias

    # Cast b0000 to 0. in fp16/bf16.
    new_exp = new_exp * torch.logical_or(exp > 0, new_mantissa > 0)

    # Cast b0001 to 0.5 in fp16/bf16.
    new_mantissa = torch.logical_and(new_mantissa, exp > 0)

    new_mantissa = new_mantissa.to(torch.int32)
    new_exp = new_exp.to(torch.int32)
    sign = sign.to(torch.int32)

    qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
        new_mantissa << (half_mantissa_bits - 1))

    assert qdq_val.max() <= 65535
    assert qdq_val.min() >= 0
    qdq_val = qdq_val.to(torch.uint16)

    result = qdq_val.view(float_dtype)

    return result


def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
                   float_dtype: torch.dtype) -> torch.Tensor:
    assert x.dtype == torch.uint8
    assert scale.dtype == torch.uint8

    if float_dtype == torch.float16:
        half_exp_bias = FLOAT16_EXP_BIAS
        half_mantissa_bits = FLOAT16_MANTISSA_BITS
    elif float_dtype == torch.bfloat16:
        half_exp_bias = BFLOAT16_EXP_BIAS
        half_mantissa_bits = BFLOAT16_MANTISSA_BITS

    scale_half = e8m0_to_half(scale, half_dtype=float_dtype)

    x_half = upcast_fp4_to_fp16_or_bf16(x,
                                        float_dtype=float_dtype,
                                        half_exp_bias=half_exp_bias,
                                        half_mantissa_bits=half_mantissa_bits)

    x_half = x_half.reshape(*x_half.shape[:-1], -1, 32)
    x_half = x_half * scale_half[..., None]
    x_half = x_half.reshape(*x_half.shape[:-2], -1)

    return x_half


def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
                         half_exp_bias: int):
    # Casts an fp16/bf16 input to the restricted values of float4_e2m1,
    # that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0,
    # -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0].

    float_type = val.dtype

    # "rshift_cuda" not implemented for 'UInt16'
    val_view = val.view(torch.int16)  #.to(torch.int32)

    exp = val_view >> half_mantissa_bits
    exp = exp & ((1 << half_exp_bits) - 1)

    exp = exp.view(torch.uint16).to(torch.int32)

    sign = (val_view >> (half_mantissa_bits + half_exp_bits)) & 1

    mantissa_last = (val_view >> (half_mantissa_bits - 1)) & 1

    exp_unbias = exp - half_exp_bias
    new_exp = exp_unbias + FLOAT4_EXP_BIAS

    exp_shift = (new_exp <= 0) * (1 - new_exp)

    # Typically 9.
    # Take the min to prevent overflow on `uint16_t half`. This is the case for
    # very small values, correctly mapped to `round_close`.
    tail_bits = half_mantissa_bits - FLOAT4_MANTISSA_BITS + exp_shift
    tail_bits[tail_bits >= 16] = 16

    mantissa_plus_one = val_view & ((1 << (half_mantissa_bits + 1)) - 1)

    half = 1 << (tail_bits - 1)

    tail = mantissa_plus_one & ((1 << tail_bits) - 1)

    round_close = (tail < half)  # round towards 0
    round_away = (tail > half)  # round away from 0
    tie = tail == half

    new_mantissa_close = torch.zeros(val.shape,
                                     device=val.device,
                                     dtype=torch.bool)
    new_exp_close = torch.zeros(val.shape,
                                device=val.device,
                                dtype=torch.uint16)

    new_mantissa_away = torch.zeros(val.shape,
                                    device=val.device,
                                    dtype=torch.bool)
    new_exp_away = torch.zeros(val.shape,
                               device=val.device,
                               dtype=torch.uint16)

    new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16)

    # 1. round down
    # if new_exp == 0:  # case [0.5, 0.749999]
    #     new_mantissa = 0
    # elif new_exp < 0:  # case [0, 0.24999]
    #     new_mantissa = 0
    # else:
    #     new_mantissa = mantissa_last

    new_mantissa_close = (new_exp > 0) * mantissa_last
    new_exp_close = exp

    # # 2. round up
    # if new_exp <= 0:  # case [0.250001, 0.499999] and [0.75001, 0.99999]
    #     new_mantissa = 0
    #     new_exp += 1
    # elif mantissa_last == 0:
    #     new_mantissa = 1
    # else:
    #     new_mantissa = 0
    #     new_exp += 1

    new_mantissa_away = torch.logical_and(new_exp > 0, mantissa_last == 0)
    new_exp_away = exp + torch.logical_or(new_exp <= 0, mantissa_last == 1)

    # # 3. tie
    # 0.25 -> 0. (handled by `exp > (half_exp_bias - 2)`)
    # 0.75 -> 1.
    # 1.25 -> 1.
    # 1.75 -> 2.
    # 2.5 -> 2.
    # 3.5 -> 4.
    # 5. -> 4.
    new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1))

    # Gather round up, round down and tie.
    new_exp = round_away * new_exp_away \
        + round_close * new_exp_close \
        + tie * new_exp_tie

    new_mantissa = round_away * new_mantissa_away \
        + round_close * new_mantissa_close

    # if new_exp > 3:
    #     new_mantissa = 1
    new_mantissa = new_mantissa + (new_exp >
                                   (2 + half_exp_bias)) * (new_mantissa == 0)

    # Clamp the exponent to acceptable values.
    new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp(
        new_exp, half_exp_bias - 2, half_exp_bias + 2)

    sign = sign.to(torch.int32)
    new_mantissa = new_mantissa.to(torch.int32)

    qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
        new_mantissa << (half_mantissa_bits - 1))

    assert qdq_val.max() <= 65535
    assert qdq_val.min() >= 0
    assert qdq_val.dtype == torch.int32
    qdq_val = qdq_val.to(torch.uint16)

    result = qdq_val.view(float_type)
    return result


def qdq_mxfp4_torch(x: torch.Tensor,
                    scale_calculation_mode: str = "even") -> torch.Tensor:
    half_dtype = x.dtype

    if half_dtype == torch.float16:
        half_mantissa_bits = FLOAT16_MANTISSA_BITS
        half_exp_bits = FLOAT16_EXP_BITS
        half_exp_bias = FLOAT16_EXP_BIAS
        val_to_add = FLOAT16_VAL_TO_ADD
        sign_exponent_mask = FLOAT16_SIGN_EXPONENT_MASK
    elif half_dtype == torch.bfloat16:
        half_mantissa_bits = BFLOAT16_MANTISSA_BITS
        half_exp_bits = BFLOAT16_EXP_BITS
        half_exp_bias = BFLOAT16_EXP_BIAS
        val_to_add = BFLOAT16_VAL_TO_ADD
        sign_exponent_mask = BFLOAT16_SIGN_EXPONENT_MASK
    else:
        raise ValueError("not implemented")

    x = x.reshape(*x.shape[:-1], -1, 32)

    block_max = torch.max(torch.abs(x), dim=-1).values

    block_max = block_max.view(torch.uint16).to(torch.int32)

    block_max_uint = torch.bitwise_and(block_max + val_to_add,
                                       sign_exponent_mask)

    assert block_max_uint.max() <= 65535
    assert block_max_uint.min() >= 0
    assert block_max_uint.dtype == torch.int32
    block_max_uint = block_max_uint.to(torch.uint16)

    block_max = block_max_uint.view(half_dtype)

    scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(
        torch.int32) - 2

    scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP)

    scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP)
    scale = scale.to(half_dtype)

    x = x / scale[..., None]

    x_fp4 = fp16_to_fp4_simulate(x,
                                 half_exp_bits=half_exp_bits,
                                 half_mantissa_bits=half_mantissa_bits,
                                 half_exp_bias=half_exp_bias)

    x_fp4 = x_fp4 * scale[..., None]
    return x_fp4.reshape(*x_fp4.shape[:-2], -1)
```
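The following usage sketch is not part of the commit; it assumes the script is run from the vLLM repository root and that the installed PyTorch supports the `torch.uint16` views used by the reference (recent PyTorch releases do). It round-trips a tensor through the simulated MXFP4 quantize-dequantize and dequantizes a hand-packed block to show the representable e2m1 values:

```python
# Usage sketch (not part of the commit) for the MXFP4 reference helpers above.
import torch

from tests.quantization.reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

# Quantize-dequantize a bfloat16 tensor. The last dimension must be a multiple
# of 32, since one e8m0 scale is shared per block of 32 elements.
x = torch.randn(4, 128, dtype=torch.bfloat16)
x_qdq = qdq_mxfp4_torch(x)
print("max abs error:", (x - x_qdq).abs().max().item())

# Dequantize hand-packed FP4 codes: each uint8 holds two e2m1 nibbles (low
# nibble first), and an e8m0 scale byte of 127 corresponds to 2**0 = 1.0, so
# this prints each of the 16 representable e2m1 values twice.
codes = torch.arange(16, dtype=torch.uint8)
packed = (codes << 4) | codes  # byte i unpacks to [i, i]
scale = torch.tensor([127], dtype=torch.uint8)
print(dq_mxfp4_torch(packed, scale, torch.bfloat16))
```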
