
Commit bb5a4ed

fp8 baselines working
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 5c4bf25

File tree: 3 files changed (+19, -10 lines)

tests/kernels/moe/test_pplx_moe.py

Lines changed: 10 additions & 7 deletions
@@ -170,9 +170,9 @@ def test_fused_moe_batched_experts(
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
+        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)  # only for baseline
         torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
-        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)  # pick torch_experts or this
 
     torch.testing.assert_close(baseline_output,
                                torch_output,
@@ -666,11 +666,14 @@ def test_pplx_moe(
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
-                                                 n,
-                                                 k,
-                                                 quant_dtype=quant_dtype,
-                                                 block_shape=block_shape)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(
+        e,
+        n,
+        k,
+        quant_dtype=quant_dtype,
+        block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant,
+    )
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
                     w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
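
Note: the new per_act_token_quant argument selects per-token rather than per-tensor activation scaling for the fp8 path. A minimal sketch of the two scale layouts the flag distinguishes, in plain PyTorch (illustrative only; fp8_scales is a hypothetical helper, not part of make_test_weights):

import torch

def fp8_scales(a: torch.Tensor, per_act_token_quant: bool) -> torch.Tensor:
    # torch.float8_e4m3fn can represent magnitudes up to 448.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if per_act_token_quant:
        # one scale per token (row): shape (num_tokens, 1)
        amax = a.abs().amax(dim=-1, keepdim=True).float()
    else:
        # one scale for the whole tensor: shape (1, 1)
        amax = a.abs().amax().float().view(1, 1)
    return amax.clamp(min=1e-12) / fp8_max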

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import cdiv
+
 
 logger = init_logger(__name__)
 
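For reference, cdiv is vllm's ceiling-division helper; it is presumably imported here so the MoE config code can turn tensor dimensions into scale-block counts when block_shape is set. A one-line sketch of the semantics (assumed definition, shown only to make the block-count arithmetic concrete):

def cdiv(a: int, b: int) -> int:
    # ceiling division: smallest integer >= a / b
    return -(a // -b)

# e.g. with block_shape == [128, 128], a hidden size of 7168 needs cdiv(7168, 128) == 56 scale blocks.
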
vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 7 additions & 3 deletions
@@ -13,6 +13,7 @@
     get_config_dtype_str, try_get_optimal_moe_config)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
+from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
 
 
 @triton.jit
@@ -555,13 +556,17 @@ def prepare(
                     rhs_a1_scale = a1_scale[:topks.numel()][topks]
                 else:
                     rhs_a1_scale = None
-                b_a1[idx, :rows, :], b_a1_scale[idx] = (moe_kernel_quantize_input(
+                b_a1[idx, :rows, :], b_s = (moe_kernel_quantize_input(
                     rhs,
                     rhs_a1_scale,
                     quant_config.quant_dtype,
                     quant_config.per_act_token_quant,
                     quant_config.block_shape,
                 ))
+                if quant_config.is_per_tensor:
+                    b_a1_scale[idx] = b_s
+                else:
+                    b_a1_scale[idx, :rows] = b_s[:rows]
             else:
                 b_a1[idx, :rows, :] = rhs
 
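The new branch above fixes how the scale returned by moe_kernel_quantize_input is written back: per-tensor quantization yields a single scale per expert, while per-token or block quantization yields one scale entry per row, so only the first rows entries of the batched scale buffer are valid. A small sketch of the shape difference (shapes are illustrative; the real buffers are allocated earlier in prepare):

import torch

num_experts, max_tokens = 4, 8
rows, idx = 5, 2  # tokens actually routed to expert idx

# per-tensor: one scale value per expert
per_tensor_scale = torch.empty((num_experts, 1, 1), dtype=torch.float32)
# per-token: one scale slot per (expert, token) pair
per_token_scale = torch.empty((num_experts, max_tokens, 1), dtype=torch.float32)

b_s = torch.rand((rows, 1))                 # what a per-token quantize might return
per_tensor_scale[idx] = torch.rand(())      # single value, broadcast over the expert
per_token_scale[idx, :rows] = b_s[:rows]    # only the occupied rows carry real scales
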
@@ -670,8 +675,7 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor:
             return t.to(f32) * scale
         else:
-            t32 = t.to(f32).view(-1, self.quant_config.block_shape[1])
-            return (t32 * scale.view(-1, 1)).view(t.shape)
+            return t.to(f32) * group_broadcast(scale, t.shape)
 
     def apply(
         self,
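
group_broadcast expands a block-wise scale tensor to the shape of the quantized tensor, so the dequant multiply now handles 2-D block shapes instead of only blocking along the last dimension. A hedged sketch of equivalent behavior for a 2-D input (illustrative only; the real helper in quant_utils may cover more layouts):

import torch

def group_broadcast_2d(scale: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    # Repeat each scale entry over its quantization block, then trim to `shape`.
    rows, cols = shape[-2], shape[-1]
    block_m = -(-rows // scale.shape[-2])   # rows covered by one scale entry
    block_n = -(-cols // scale.shape[-1])   # cols covered by one scale entry
    out = scale.repeat_interleave(block_m, dim=-2)
    out = out.repeat_interleave(block_n, dim=-1)
    return out[..., :rows, :cols]

t = torch.randn(10, 256)
scale = torch.rand(1, 2)                             # one scale per (128, 128) block of t
dequant = t * group_broadcast_2d(scale, t.shape)     # same shape as t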
