Commit daa9b0a

review comments + test fixes

Signed-off-by: Bill Nell <bnell@redhat.com>

1 parent 90772e8 commit daa9b0a

18 files changed, +382 -348 lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 31 additions & 9 deletions
@@ -85,9 +85,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
-    if block_shape is not None and not use_fp8_w8a8:
+    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
         pytest.skip("Don't test blocking for non-quantized types.")
 
+    if per_act_token_quant and block_shape is not None:
+        pytest.skip("Skip illegal quantization test.")
+
     if dtype.itemsize == 1:
         act_dtype = torch.bfloat16
         quant_dtype = dtype
@@ -201,11 +204,11 @@ def test_fused_moe_batched_experts(
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
-    if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
+    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
         pytest.skip("Skip quantization test for non-quantized type")
 
     if per_act_token_quant and block_shape is not None or topk > e:
-        pytest.skip("Skip illegal quantization test")
+        pytest.skip("Skip illegal quantization test.")
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
@@ -226,9 +229,18 @@ def test_fused_moe_batched_experts(
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_dtype, per_act_token_quant,
-                                     block_shape)
+        batched_output = batched_moe(
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
         baseline_output = torch_experts(
             a,
             w1,
@@ -240,9 +252,19 @@ def test_fused_moe_batched_experts(
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape)
-        triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_dtype, per_act_token_quant,
-                                   block_shape)
+
+        triton_output = triton_moe(
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
 
         torch.testing.assert_close(triton_output,
                                    baseline_output,
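
The new skip conditions above gate the quantization parametrizations: activation-quantization options are only exercised for fp8 weights, and per-token activation quantization is never combined with block quantization. A condensed, standalone sketch of that logic (the maybe_skip helper name is made up for illustration and is not part of the commit):

    import pytest
    import torch

    def maybe_skip(dtype, per_act_token_quant, block_shape):
        # fp8 weights are the only quantized weight type exercised in this test.
        use_fp8_w8a8 = dtype == torch.float8_e4m3fn
        # Quantization knobs are meaningless for unquantized dtypes.
        if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
            pytest.skip("Don't test blocking for non-quantized types.")
        # Per-token and block-wise activation quantization are mutually exclusive.
        if per_act_token_quant and block_shape is not None:
            pytest.skip("Skip illegal quantization test.")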

tests/kernels/moe/test_block_fp8.py

Lines changed: 48 additions & 106 deletions
@@ -9,9 +9,10 @@
 from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul,
                                        per_block_cast_to_fp8)
+from tests.kernels.moe.utils import make_test_weights
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
@@ -55,13 +56,13 @@
 SEEDS = [0]
 
 
-def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
+def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
     """Fused moe with block-wise quantization using native torch."""
     B, D = a.shape
+    topk = topk_ids.size(1)
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
-    score = torch.softmax(score, dim=-1, dtype=torch.float32)
-    topk_weight, topk_ids = torch.topk(score, topk)
+
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
 
@@ -112,80 +113,59 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
 
-    factor_for_scale = 1e-2
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
     a = torch.randn((M, K), dtype=dtype) / 10
-
-    w1_bf16 = (torch.rand(
-        (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
-    w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-    del w1_bf16
-
-    w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
-    w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-    del w2_bf16
-
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = (2 * N + block_n - 1) // block_n
-    n_tiles_w2 = (K + block_n - 1) // block_n
-    k_tiles_w1 = (K + block_k - 1) // block_k
-    k_tiles_w2 = (N + block_k - 1) // block_k
-
-    w1_s = torch.rand(
-        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
-    w2_s = torch.rand(
-        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
-
     score = torch.randn((M, E), dtype=dtype)
 
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+                                                 per_act_token_quant=False,
+                                                 block_shape=block_size)
+
     m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                            use_int8_w8a8=False,
                                            use_int8_w8a16=False,
                                            use_int4_w4a16=False,
                                            per_act_token_quant=False,
                                            block_shape=block_size)
 
+    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        out = fused_moe(
+        ref_out = torch_w8a8_block_fp8_moe(
             a,
             w1,
             w2,
-            score,
-            topk,
-            renormalize=False,
+            w1_s,
+            w2_s,
+            topk_weights,
+            topk_ids,
+            block_size,
+        )
+
+        out = fused_experts(
+            a,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
             use_fp8_w8a8=True,
             w1_scale=w1_s,
             w2_scale=w2_s,
             block_shape=block_size,
         )
-        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
-                                           block_size)
 
-        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
-        m_out = m_fused_moe(a,
-                            w1,
-                            w2,
-                            topk_weights,
-                            topk_ids,
-                            global_num_experts=E,
-                            w1_scale=w1_s,
-                            w2_scale=w2_s)
-
-        #print(f"{out.sum()=}")
-        #print(f"{ref_out.sum()=}")
-
-        rel_diff = (torch.mean(
-            torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
-                    torch.mean(torch.abs(ref_out.to(torch.float32))))
-        assert rel_diff < 0.03
+        m_out = m_fused_moe(
+            a,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+        )
 
-        rel_diff = (torch.mean(
-            torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
-                    torch.mean(torch.abs(ref_out.to(torch.float32))))
-        assert rel_diff < 0.03
+        torch.testing.assert_close(out, ref_out, atol=0.03, rtol=0.03)
+        torch.testing.assert_close(m_out, ref_out, atol=0.03, rtol=0.03)
 
 
 def fp8_perm(m, idx):
@@ -221,15 +201,13 @@ def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
     return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
+def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                                  block_shape):
     """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
     num_groups = w1.shape[0]
     M, K = a.shape
     N = w2.shape[-1]
-
-    topk_weight, topk_ids, token_expert_indices = fused_topk(
-        a, score.float(), topk, False)
+    topk = topk_ids.size(1)
 
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
 
@@ -282,40 +260,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
 
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
     a = torch.randn((M, K), dtype=dtype) / 10
-
-    w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
-               fp8_max).clamp(min=fp8_min, max=fp8_max)
-
-    w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
-               fp8_max).clamp(min=fp8_min, max=fp8_max)
-
     score = torch.randn((M, E), dtype=dtype)
 
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
-    k_tiles_w1 = (K + block_k - 1) // block_k
-    n_tiles_w2 = (K + block_n - 1) // block_n
-    k_tiles_w2 = (N + block_k - 1) // block_k
-
-    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
-    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
-
-    w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
-    w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
-
-    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
-    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
-
-    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
-    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
-
-    for i in range(E):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+                                                 per_act_token_quant=False,
+                                                 block_shape=block_size)
 
     # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
@@ -325,17 +275,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                      and current_platform.is_cuda_alike())
 
+    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        if M >= 128:
+        if False and M >= 128:
             ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
-                                                   score, topk, block_size)
+                                                   topk_weights, topk_ids, block_size)
         else:
-            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
-                                               topk, block_size)
-
-        topk_weights, topk_ids, token_expert_indices = fused_topk(
-            a, score.float(), topk, False)
+            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
+                                               topk_ids, block_size)
 
         if use_compile:
             deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
@@ -361,11 +310,4 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
             graph.replay()
         torch.cuda.synchronize()
 
-    #print(f"{out.sum()=}")
-    #print(f"{ref_out.sum()=}")
-
-    rel_diff = (torch.mean(
-        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
-                torch.mean(torch.abs(ref_out.to(torch.float32))))
-
-    assert rel_diff < 0.03
+    torch.testing.assert_close(out, ref_out, atol=0.03, rtol=0.03)
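
For reference, the tolerance change above swaps a mean relative-difference check for torch.testing.assert_close. A minimal sketch of the two styles (the function names are illustrative, not from the commit); the new check is applied elementwise, |out - ref_out| <= atol + rtol * |ref_out|, rather than to the mean, so it is a somewhat stricter criterion at the same 0.03 tolerance:

    import torch

    def old_style_check(out: torch.Tensor, ref_out: torch.Tensor, tol: float = 0.03):
        # Mean absolute error relative to the mean magnitude of the reference.
        rel_diff = (torch.mean(torch.abs(out.float() - ref_out.float())) /
                    torch.mean(torch.abs(ref_out.float())))
        assert rel_diff < tol

    def new_style_check(out: torch.Tensor, ref_out: torch.Tensor, tol: float = 0.03):
        # Elementwise: every entry must satisfy |out - ref_out| <= atol + rtol * |ref_out|.
        torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)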

tests/kernels/moe/test_block_int8.py

Lines changed: 23 additions & 22 deletions
@@ -8,6 +8,7 @@
 
 from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
                                        native_w8a8_block_matmul)
+from tests.kernels.moe.utils import make_test_weights
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
@@ -85,31 +86,34 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     torch.manual_seed(seed)
     # Use a smaller factor for scale initialization to prevent large
     # values/overflow especially when output dtype might be float16
-    factor_for_scale = 1e-2
-    int8_info = torch.iinfo(torch.int8)
-    int8_max, int8_min = int8_info.max, int8_info.min
+    # factor_for_scale = 1e-2
+    # int8_info = torch.iinfo(torch.int8)
+    # int8_max, int8_min = int8_info.max, int8_info.min
 
     a = torch.randn((M, K), dtype=dtype) / 10
+    score = torch.randn((M, E), dtype=dtype)
 
-    w1_fp32 = (torch.rand(
-        (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
-    w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
+    # w1_fp32 = (torch.rand(
+    #     (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
+    # w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
 
-    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
-    w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
+    # w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
+    # w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
 
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = (2 * N + block_n - 1) // block_n
-    n_tiles_w2 = (K + block_n - 1) // block_n
-    k_tiles_w1 = (K + block_k - 1) // block_k
-    k_tiles_w2 = (N + block_k - 1) // block_k
+    # block_n, block_k = block_size[0], block_size[1]
+    # n_tiles_w1 = (2 * N + block_n - 1) // block_n
+    # n_tiles_w2 = (K + block_n - 1) // block_n
+    # k_tiles_w1 = (K + block_k - 1) // block_k
+    # k_tiles_w2 = (N + block_k - 1) // block_k
 
-    w1_s = (torch.rand(
-        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
-    w2_s = (torch.rand(
-        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
+    # w1_s = (torch.rand(
+    #     (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
+    # w2_s = (torch.rand(
+    #     (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
 
-    score = torch.randn((M, E), dtype=dtype)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.int8,
+                                                 per_act_token_quant=False,
+                                                 block_shape=block_size)
 
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
@@ -129,7 +133,4 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
                                             block_size)
 
     # Check results
-    rel_diff = (torch.mean(
-        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
-                torch.mean(torch.abs(ref_out.to(torch.float32))))
-    assert rel_diff < 0.06
+    torch.testing.assert_close(out, ref_out, atol=0.06, rtol=0.06)
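
The make_test_weights call above stands in for the inline weight construction these tests used to carry. The sketch below reorganizes the deleted inline int8 code into a hypothetical helper for context; it is not the actual implementation of make_test_weights in tests/kernels/moe/utils.py, which (per the call sites) takes (E, N, K, in_dtype, quant_dtype, per_act_token_quant, block_shape) and returns a 6-tuple of unquantized and quantized weights plus scales.

    import torch

    def manual_block_int8_weights(E: int, N: int, K: int, block_size: list[int]):
        # Small factor keeps scales modest and avoids overflow in float16 outputs.
        factor_for_scale = 1e-2
        int8_info = torch.iinfo(torch.int8)
        int8_max, int8_min = int8_info.max, int8_info.min

        # Random weights spanning the int8 range, clamped and cast to int8.
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
        w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
        w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

        # One scale per (block_n x block_k) tile of each weight matrix.
        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k
        w1_s = torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
        w2_s = torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
        return w1, w1_s, w2, w2_s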
