
Commit c7d8724

wenscarl and mgoin authored
[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)
Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent b38baab commit c7d8724

22 files changed: +1095 / -271 lines

vllm/_custom_ops.py

Lines changed: 9 additions & 13 deletions
@@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                                 c_strides, per_act_token, per_out_ch)
 
 
-def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
-                       a_scales: torch.Tensor, b_scales: torch.Tensor,
-                       alphas: torch.Tensor, problem_sizes: torch.Tensor,
-                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
-                       out_dtype: torch.dtype, device: torch.device):
+def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
+                       b_tensors: torch.Tensor, a_scales: torch.Tensor,
+                       b_scales: torch.Tensor, alphas: torch.Tensor,
+                       problem_sizes: torch.Tensor,
+                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
     """
     An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
     the gemms for each combination based on the specified problem sizes.
@@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
     - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
       MMs used in the fused MoE operation.
     """
-    m_topk = a_tensors.shape[0]
-    n = b_tensors.shape[1]
-    c_shape = (m_topk, n)
-    c = torch.empty(c_shape, device=device, dtype=out_dtype)
-    torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
-                                      b_scales, alphas, problem_sizes,
-                                      expert_offsets, sf_offsets)
-    return c.to(out_dtype)
+    return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
+                                             a_scales, b_scales, alphas,
+                                             problem_sizes, expert_offsets,
+                                             sf_offsets)
 
 
 # aqlm
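Editor's note: with this change cutlass_fp4_moe_mm no longer allocates the result tensor; the caller now owns the output buffer, matching the other CUTLASS MoE entry points. A minimal caller-side sketch, assuming the shapes follow the removed body (m_topk rows from a_tensors, n columns from b_tensors); this is not vLLM's actual call site:

import torch
from vllm import _custom_ops as ops

def fp4_group_mm_with_preallocated_out(a_tensors, b_tensors, a_scales,
                                       b_scales, alphas, problem_sizes,
                                       expert_offsets, sf_offsets,
                                       out_dtype=torch.half):
    # The caller now allocates the output; its shape mirrors what the old
    # function body computed internally (m_topk = tokens * topk, n = output
    # features of the grouped GEMM).
    m_topk = a_tensors.shape[0]
    n = b_tensors.shape[1]
    out = torch.empty((m_topk, n), device=a_tensors.device, dtype=out_dtype)
    ops.cutlass_fp4_moe_mm(out, a_tensors, b_tensors, a_scales, b_scales,
                           alphas, problem_sizes, expert_offsets, sf_offsets)
    return out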

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -119,6 +119,7 @@
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_USE_FLASHINFER_MOE: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -853,6 +854,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
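Editor's note: the new flag is off by default and, like the other entries in this table, is resolved lazily when vllm.envs is queried. A small illustrative snippet (the flag name comes from this diff; everything else is ordinary environment handling):

import os

# Enable the FlashInfer CUTLASS fused-MoE path; "0" (the default) keeps the
# existing backends. Set this before vLLM constructs its MoE layers.
os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

import vllm.envs as envs

# The lambda registered in the diff above converts the string to a bool at
# access time, so this reflects the value set just before the lookup.
print(envs.VLLM_USE_FLASHINFER_MOE)  # True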

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 13 additions & 23 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 
@@ -255,28 +255,18 @@ def workspace_shapes(
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-        apply_router_weight_on_input: bool,
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
 

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 
@@ -142,12 +142,13 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
               expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool):
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
         experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
                       activation, global_num_experts, expert_map, w1_scale,
                       w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
                       workspace2, expert_tokens_meta,
-                      apply_router_weight_on_input)
+                      apply_router_weight_on_input, extra_expert_args)
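Editor's note: both files above widen apply() with a trailing extra_expert_args dict so backend-specific options can flow through the modular-kernel interface without another signature change. A sketch of the consuming pattern this enables; the key name used is an illustrative placeholder, not one introduced by this commit:

from typing import Any, Optional

def read_expert_options(
        extra_expert_args: Optional[dict[str, Any]]) -> dict[str, Any]:
    # The argument may legitimately be None (callers with nothing to pass
    # simply forward None, as the DeepGEMM experts do), so normalize it
    # before any lookups.
    return extra_expert_args or {}

# Hypothetical backend-side use: pull a named knob with a default and ignore
# anything unrecognized. "output_dtype" here is illustrative only.
opts = read_expert_options({"output_dtype": "bfloat16"})
output_dtype = opts.get("output_dtype", "half")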

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 16 additions & 0 deletions
@@ -15,6 +15,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import cdiv
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 logger = init_logger(__name__)
 
@@ -188,6 +189,11 @@ def use_deepep_ll_kernels(self):
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
 
+    @property
+    def use_flashinfer_cutlass_kernels(self):
+        return (envs.VLLM_USE_FLASHINFER_MOE
+                and has_flashinfer_cutlass_fused_moe())
+
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -392,6 +398,10 @@ def use_deepep_ht_kernels(self):
     def use_deepep_ll_kernels(self):
         return self.moe_parallel_config.use_deepep_ll_kernels
 
+    @property
+    def use_flashinfer_cutlass_kernels(self):
+        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+
     @staticmethod
     def make(
         num_experts: int,
@@ -435,6 +445,12 @@ def make(
         if quant_dtype is None and isinstance(quant_config, Fp8Config):
             quant_dtype = torch.float8_e4m3fn
 
+        from vllm.model_executor.layers.quantization.modelopt import (
+            ModelOptNvFp4Config)
+        if quant_dtype is None and isinstance(quant_config,
+                                              ModelOptNvFp4Config):
+            quant_dtype = torch.uint8
+
         if weight_quant is not None:
             per_out_ch_quant = (
                 weight_quant.strategy == QuantizationStrategy.CHANNEL)
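Editor's note: the two new properties simply combine the environment opt-in with a runtime availability check, so the FlashInfer CUTLASS path is only selected when the kernels can actually be imported. A standalone restatement of that gate under the same assumption; it mirrors the property body above and is not how vLLM dispatches internally:

import vllm.envs as envs
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

def flashinfer_cutlass_moe_enabled() -> bool:
    # Mirrors FusedMoEParallelConfig.use_flashinfer_cutlass_kernels: the user
    # must opt in via VLLM_USE_FLASHINFER_MOE *and* FlashInfer's CUTLASS
    # fused-MoE kernels must be available in this environment.
    return bool(envs.VLLM_USE_FLASHINFER_MOE
                and has_flashinfer_cutlass_fused_moe())

if __name__ == "__main__":
    backend = ("FlashInfer CUTLASS (NVFP4)" if flashinfer_cutlass_moe_enabled()
               else "existing fused-MoE backends")
    print(f"MoE backend selection: {backend}")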
