
Commit 911339b
wip
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent c9d0f4f

11 files changed: +33, -41 lines


tests/kernels/moe/test_batched_moe.py

Lines changed: 3 additions & 0 deletions
@@ -150,6 +150,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
+    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+    tensors = BatchedMMTensors.make_tensors(config)
+
     per_act_token_quant = False
 
     if block_shape is not None and not use_fp8_w8a8:
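
The added lines build the test inputs through a small config object plus a tensor factory. Below is a minimal, self-contained sketch of that pattern; the Demo* names, fields, and shapes are assumptions for illustration only, not the actual BatchedMMConfig/BatchedMMTensors helpers in vLLM's test suite.

# Hypothetical sketch of a config + tensor-factory pair for a batched-MM test.
from dataclasses import dataclass

import torch


@dataclass
class DemoBatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class DemoBatchedMMTensors:
    # One slice of A/B/C per expert, padded to max_tokens_per_expert rows.
    A: torch.Tensor  # (E, max_tokens_per_expert, K)
    B: torch.Tensor  # (E, N, K)
    C: torch.Tensor  # (E, max_tokens_per_expert, N)

    @staticmethod
    def make_tensors(config: DemoBatchedMMConfig) -> "DemoBatchedMMTensors":
        E, M = config.num_experts, config.max_tokens_per_expert
        K, N = config.K, config.N
        # Allocate in float32 here; the real helpers also handle quantized
        # dtypes such as float8_e4m3fn separately.
        A = torch.randn(E, M, K, dtype=torch.float32) / 10
        B = torch.randn(E, N, K, dtype=torch.float32) / 10
        C = torch.zeros(E, M, N, dtype=torch.float32)
        return DemoBatchedMMTensors(A, B, C)


tensors = DemoBatchedMMTensors.make_tensors(
    DemoBatchedMMConfig(torch.float32, num_experts=4,
                        max_tokens_per_expert=16, K=32, N=64))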

tests/kernels/moe/test_pplx_moe.py

Lines changed: 9 additions & 10 deletions
@@ -18,22 +18,21 @@
 except ImportError:
     has_pplx = False
 
-#from tests.kernels.quant_utils import native_w8a8_block_matmul
 from tests.kernels.moe.utils import (make_test_weights, naive_batched_moe,
                                      torch_moe2)
-from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe import (BatchedTritonExperts,
-                                                  FusedMoEConfig,
-                                                  FusedMoEModularKernel,
-                                                  fused_topk,
-                                                  get_default_config,
-                                                  override_config)
+from vllm.model_executor.layers.fused_moe import override_config
+from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize, NaiveBatchedExperts)
+    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
+from .deepep_utils import ProcessGroupInfo, parallel_launch
+
 requires_pplx = pytest.mark.skipif(
     not has_pplx,
     reason="Requires PPLX kernels",

@@ -542,7 +541,7 @@ def _pplx_moe(
     qtype: Optional[torch.dtype] = None,
     per_act_token_quant: bool = False,
     block_shape: Optional[list[int]] = None,
-    use_internode: bool,
+    use_internode: bool = False,
 ):
     if use_internode:
         uid = nvshmem_get_unique_id(
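
Giving use_internode a default of False is the behavioral piece of the second hunk: depending on whether the parameter is ordinary or keyword-only in the full signature, the old form either was invalid (a non-default parameter may not follow defaulted ones) or forced every caller to pass the flag. A tiny hypothetical illustration of what the default buys, not the real _pplx_moe signature:

# Hypothetical trimmed-down driver; callers that never exercise internode
# mode can simply omit the flag.
from typing import Optional


def launch_case(block_shape: Optional[list[int]] = None,
                use_internode: bool = False) -> str:
    return "internode" if use_internode else "intranode"


assert launch_case() == "intranode"            # existing call sites keep working
assert launch_case(use_internode=True) == "internode"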

tests/kernels/moe/utils.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def torch_moe2(
     a, a_scale = moe_kernel_quantize_input(a, None, quant_type,
                                            per_act_token_quant, block_shape)
 
-    print(f"XXX {quant_type} {block_shape} {a.shape} {a_scale}")
+    #print(f"XXX {quant_type} {block_shape} {a.shape} {a_scale}")
 
     out = torch.zeros(M * topk,
                       w2.shape[1],

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 2 additions & 5 deletions
@@ -65,13 +65,9 @@ def __init__(self,
             max_num_tokens=self.max_num_tokens,
             world_size=self.world_size,
             dp_size=self.dp_size,
-            block_shape=self.block_shape,
+            block_shape=self.block_shape,  # type: ignore[arg-type]
         ) if self.allow_deep_gemm else None
 
-        assert (self.batched_triton_experts is not None
-                or (self.allow_deep_gemm
-                    and self.batched_deep_gemm_experts is not None))
-
         assert (self.batched_deep_gemm_experts is not None
                 or self.batched_triton_experts is not None)
 
@@ -96,6 +92,7 @@ def workspace_shapes(
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm:
+            assert self.batched_deep_gemm_experts is not None
            return self.batched_deep_gemm_experts.workspace_shapes(
                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
        else:
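
Both hunks tighten the Optional handling: the constructor keeps a single invariant (at least one backend exists), and workspace_shapes asserts the deep-gemm backend right before dereferencing it, which is also what narrows the Optional for a type checker. Below is a hypothetical sketch of that dispatch pattern under the assumption that the deep-gemm backend needs the larger workspace; these are not the vLLM classes.

# Sketch: two optional backends, pessimistic workspace sizing, assert-to-narrow.
from typing import Optional


class _Backend:
    """Stand-in for a batched-experts implementation."""

    def __init__(self, workspace_elems: int):
        self.workspace_elems = workspace_elems

    def workspace_shapes(self, M: int, N: int) -> tuple[int, int]:
        return (M, max(N, self.workspace_elems))


class _Dispatcher:
    def __init__(self, allow_deep_gemm: bool):
        self.allow_deep_gemm = allow_deep_gemm
        # Assumption: the deep-gemm-like backend needs the larger workspace.
        self.deep_gemm: Optional[_Backend] = (
            _Backend(workspace_elems=256) if allow_deep_gemm else None)
        self.triton: Optional[_Backend] = _Backend(workspace_elems=128)
        # One invariant is enough: at least one backend must exist.
        assert self.deep_gemm is not None or self.triton is not None

    def workspace_shapes(self, M: int, N: int) -> tuple[int, int]:
        # Be pessimistic: if deep-gemm is allowed, size for it even if the
        # actual call later falls back to the triton path.
        if self.allow_deep_gemm:
            assert self.deep_gemm is not None  # narrows the Optional for mypy
            return self.deep_gemm.workspace_shapes(M, N)
        assert self.triton is not None
        return self.triton.workspace_shapes(M, N)


print(_Dispatcher(allow_deep_gemm=True).workspace_shapes(64, 128))  # (64, 256)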

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 1 addition & 6 deletions
@@ -347,19 +347,14 @@ def cutlass_moe_fp8(
                     a2_scale.numel() != 1 if a2_scale is not None else False)
     per_out_ch = w1_scale.numel() != w1_q.shape[0]
 
-    out_dtype = a.dtype
-
-    if out_dtype is None:
-        out_dtype = a.dtype
-
     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
         0)
 
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
             max_experts_per_worker=num_experts,
-            out_dtype=out_dtype,
+            out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
             use_batched_format=False,

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 6 additions & 6 deletions
@@ -627,11 +627,6 @@ def __init__(
         block_shape: Optional[list[int]] = None,
         per_act_token_quant: bool = False,
     ):
-        super().__init__()
-        assert not use_fp8_w8a8, "NYI"
-        assert not use_int8_w8a8, "NYI"
-        assert not use_int8_w8a16, "NYI"
-        assert not use_int4_w4a16, "NYI"
         super().__init__(
             FusedMoEQuantConfig.make(
                 use_fp8_w8a8=use_fp8_w8a8,

@@ -641,6 +636,10 @@
                 per_act_token_quant=per_act_token_quant,
                 block_shape=block_shape,
             ))
+        assert not use_fp8_w8a8, "NYI"
+        assert not use_int8_w8a8, "NYI"
+        assert not use_int8_w8a16, "NYI"
+        assert not use_int4_w4a16, "NYI"
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.dp_size = dp_size

@@ -928,7 +927,8 @@ def apply(
         intermediate_cache2 = _resize_cache(workspace2,
                                             (E, max_num_tokens, N // 2))
 
-        intermediate_cache1.fill_(0)
+        if self.use_fp8_w8a8:
+            intermediate_cache1.fill_(0)
 
         #print(f"A1_SCALES {a1q_scale.shape}")
 
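The last hunk stops unconditionally clearing intermediate_cache1 and only zero-fills it on the fp8 path. A hedged sketch of the idea follows: a workspace view carved out of a larger buffer may hold stale data, and the clear is paid only on the path that is assumed to read otherwise-unwritten regions. The names and shapes here are hypothetical, not the vLLM kernel code.

# Sketch: view a slice of a reused workspace and zero it only when needed.
import math

import torch


def prepare_cache(workspace: torch.Tensor, shape: tuple[int, ...],
                  use_fp8_w8a8: bool) -> torch.Tensor:
    """Return a view of `workspace` with `shape`, optionally zeroed."""
    cache = workspace[:math.prod(shape)].view(shape)
    if use_fp8_w8a8:
        # Zeroing costs a full pass over the buffer, so only do it on the
        # path assumed to depend on a clean buffer.
        cache.fill_(0)
    return cache


ws = torch.empty(4 * 16 * 32)
c1 = prepare_cache(ws, (4, 16, 32), use_fp8_w8a8=True)
assert bool((c1 == 0).all())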

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 5 deletions
@@ -93,9 +93,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig,
             block_shape=moe.block_shape,
         )
 
-        logger.debug("All2All %s, %s = %s/%s", moe.quant_dtype,
-                     moe.block_shape, hidden_dim_bytes, hidden_scale_bytes)
-
         all_to_all_args = dict(
             max_num_tokens=moe.max_num_tokens,
             num_experts=moe.num_experts,

@@ -223,7 +220,8 @@ def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig
-    ):
+    ) -> FusedMoEPermuteExpertsUnpermute:
+
         assert self.fused_experts == fused_experts
 
         all2all_manager = get_ep_group().device_communicator.all2all_manager

@@ -664,7 +662,6 @@ def __init__(
 
         logger.debug("MODEL DTYPE %s", model_dtype)
 
-        # TODO: put quant info into FusedMoEConifg here
         moe = FusedMoEConfig.make(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
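
The select_gemm_impl hunk here (and the matching one in compressed_tensors_moe.py below) replaces an unannotated signature with explicit parameter and return types, so a type checker can verify that every override actually returns an experts implementation. A stripped-down hypothetical illustration, not vLLM's class hierarchy:

# Sketch: a precisely typed hook makes "forgot to return" a mypy error.
from abc import ABC, abstractmethod


class ExpertsImpl:
    """Stand-in for FusedMoEPermuteExpertsUnpermute."""


class MoEMethodBase(ABC):
    @abstractmethod
    def select_gemm_impl(self, prepare_finalize: object,
                         moe: object) -> ExpertsImpl:
        ...


class CutlassLikeMethod(MoEMethodBase):
    def select_gemm_impl(self, prepare_finalize: object,
                         moe: object) -> ExpertsImpl:
        # With the declared return type, implicitly returning None would be
        # flagged by mypy instead of surfacing later at runtime.
        return ExpertsImpl()


impl = CutlassLikeMethod().select_gemm_impl(None, None)
assert isinstance(impl, ExpertsImpl)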

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 0 additions & 2 deletions
@@ -88,8 +88,6 @@ def _moe_problem_size(
 
 
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
-
-
 class FusedMoEPrepareAndFinalize(ABC):
     """
     An abstract base class for the [Quantize-Prepare] and [Finalize] steps

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def workspace_shapes(
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if (self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K)):
+        if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             return self.deep_gemm_expert.workspace_shapes(
                 a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 1 deletion
@@ -581,7 +581,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
-    def select_gemm_impl(self, prepare_finalize, moe):
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: MoEConfig,
+    ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             CutlassExpertsFp8)
 