From 951dc4ede2f334434c297cf6ea8f16c38e5f17a1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 10 Jul 2025 14:44:50 +0000 Subject: [PATCH 1/5] Performance improvements in non-blockwise CUTLASS MoE Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 35 ++++++++- csrc/moe/moe_permute_unpermute_op.cu | 53 ++++++++++--- tests/kernels/moe/test_cutlass_moe.py | 14 +++- tests/kernels/moe/test_pplx_cutlass_moe.py | 22 ++++++ .../layers/fused_moe/cutlass_moe.py | 62 +++++++++------ .../compressed_tensors_moe.py | 77 +++++++++++++++---- 6 files changed, 209 insertions(+), 54 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1d4e730f99a..a6b42406b5c 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -80,6 +80,11 @@ def bench_run( a, score, topk, renormalize=False ) + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -111,6 +116,10 @@ def run_cutlass_moe( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -125,6 +134,10 @@ def run_cutlass_moe( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -136,6 +149,10 @@ def run_cutlass_from_graph( w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -150,6 +167,10 @@ def run_cutlass_from_graph( topk_ids, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, per_act_token, a1_scale=None, ) @@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats): w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, ) @@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats): "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, + "ab_strides1": ab_strides1, + "ab_strides2": ab_strides2, + "c_strides1": c_strides1, + "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats): w2_q, w1_scale, w2_scale, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, topk_weights, topk_ids, per_act_token, @@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats): results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index a77471a7f20..13aecd8007a 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input, } } +template +__global__ void shuffleInputRowsKernelSlow(const T* input, + const int32_t* dst2src_map, + T* output, int64_t num_src_rows, + int64_t num_dst_rows, + int64_t num_cols) { + int64_t dest_row_idx = blockIdx.x; + int64_t const source_row_idx = dst2src_map[dest_row_idx]; + + if (blockIdx.x < num_dst_rows) { + // Duplicate and permute rows + auto const* source_row_ptr = input + source_row_idx * num_cols; + auto* dest_row_ptr = output + dest_row_idx * num_cols; + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + + for (int elem_index = start_offset; elem_index < num_cols; + elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) { @@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor, int64_t const num_src_rows = input_tensor.size(0); int64_t const num_cols = input_tensor.size(1); - TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)), - "num_cols must be divisible by 128 / " - "sizeof(input_tensor.scalar_type()) / 8"); - - MOE_DISPATCH(input_tensor.scalar_type(), [&] { - shuffleInputRowsKernel<<>>( - reinterpret_cast(input_tensor.data_ptr()), - dst2src_map.data_ptr(), - reinterpret_cast(output_tensor.data_ptr()), num_src_rows, - num_dest_rows, num_cols); - }); + if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) { + // use slow kernel if num_cols can't be aligned to 128 bits + MOE_DISPATCH(input_tensor.scalar_type(), [&] { + shuffleInputRowsKernelSlow<<>>( + reinterpret_cast(input_tensor.data_ptr()), + dst2src_map.data_ptr(), + reinterpret_cast(output_tensor.data_ptr()), num_src_rows, + num_dest_rows, num_cols); + }); + } else { + MOE_DISPATCH(input_tensor.scalar_type(), [&] { + shuffleInputRowsKernel<<>>( + reinterpret_cast(input_tensor.data_ptr()), + dst2src_map.data_ptr(), + reinterpret_cast(output_tensor.data_ptr()), num_src_rows, + num_dest_rows, num_cols); + }); + } } #else diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 5fac7166bc2..5fb49c2da4f 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'topk_ids': topk_ids, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, + 'ab_strides1': moe_tensors.ab_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides1': moe_tensors.c_strides1, + 'c_strides2': moe_tensors.c_strides2, 'per_act_token': per_act_token, 'a1_scale': None #moe_tensors.a_scale } @@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, torch.float8_e4m3fn, @@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8( func = lambda output: run_cutlass_moe_fp8( output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, - per_act_token, per_out_channel, False) + a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, + workspace13, workspace2, None, mt.a.dtype, per_act_token, + per_out_channel, False) workspace13.random_() output_random_workspace = torch.empty(output_shape, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index e4f4a393dfd..77adc89ea9d 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -75,6 +75,7 @@ def pplx_cutlass_moe( assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape + intermediate_dim = w2.shape[2] num_experts = w1.shape[0] block_size = hidden_dim # TODO support more cases device = pgi.device @@ -123,10 +124,31 @@ def pplx_cutlass_moe( num_local_experts=num_local_experts, num_dispatchers=num_dispatchers) + ab_strides1 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_local_experts, ), + intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_local_experts, ), + 2 * intermediate_dim, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_local_experts, ), + hidden_dim, + device="cuda", + dtype=torch.int64) + experts = CutlassExpertsFp8(num_local_experts, out_dtype, per_act_token, per_out_ch, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, num_dispatchers=num_dispatchers, use_batched_format=True) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d771a7a54cf..20e3e4582f9 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -11,8 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.scalar_type import scalar_types @@ -32,6 +31,10 @@ def run_cutlass_moe_fp8( w2_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], @@ -150,27 +153,11 @@ def run_cutlass_moe_fp8( problem_sizes1, problem_sizes2, a_map, c_map, global_num_experts, N, K) - a1q = _fp8_perm(a1q, a_map) - a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + a1q = ops.shuffle_rows(a1q, a_map) + a1q_scale = (ops.shuffle_rows(a1q_scale, a_map) + if per_act_token else a1q_scale) expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - c_strides1 = torch.full((w1.size(0), ), - 2 * N, - device=device, - dtype=torch.int64) - ab_strides2 = torch.full((w1.size(0), ), - N, - device=device, - dtype=torch.int64) - c_strides2 = torch.full((w1.size(0), ), - K, - device=device, - dtype=torch.int64) - if use_batched_format: c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) c2 = _resize_cache(workspace2, (local_E * padded_M, N)) @@ -207,7 +194,8 @@ def run_cutlass_moe_fp8( else: # We can't do this inplace because output may point to the same tensor # as c3. - output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K), + non_blocking=True) # TODO (bnell): split class batched vs. non-batched? @@ -220,6 +208,10 @@ def __init__( out_dtype: Optional[torch.dtype], per_act_token_quant: bool, per_out_ch_quant: bool, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, block_shape: Optional[list[int]] = None, num_dispatchers: Optional[int] = None, use_batched_format: bool = False, @@ -236,6 +228,10 @@ def __init__( self.max_experts_per_worker = max_experts_per_worker self.num_dispatchers = num_dispatchers self.out_dtype = out_dtype + self.ab_strides1 = ab_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 self.use_batched_format = use_batched_format @property @@ -312,7 +308,8 @@ def apply( run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, expert_num_tokens, + a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, + self.c_strides2, workspace13, workspace2, expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, self.per_act_token_quant, self.per_out_ch_quant, self.use_batched_format) @@ -326,6 +323,10 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, @@ -353,6 +354,17 @@ def cutlass_moe_fp8( Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] + - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. + Shape: [num_experts] + - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. + Shape: [num_experts] + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to @@ -385,6 +397,10 @@ def cutlass_moe_fp8( out_dtype=a.dtype, per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, use_batched_format=False, ), ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ef67cc0eda4..0733d75b1c0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -737,10 +737,8 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) self.topk_indices_dtype = None - self.fused_experts = cutlass_moe_fp8 # type: ignore + self.fused_experts = None self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -861,6 +859,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + device = layer.w13_weight.device + self.ab_strides1 = torch.full((layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((layer.local_num_experts, ), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((layer.local_num_experts, ), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -883,6 +899,10 @@ def select_gemm_impl( moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + ab_strides1=self.ab_strides1, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.c_strides2, num_dispatchers=num_dispatchers, use_batched_format=use_batched_format, ) @@ -933,20 +953,43 @@ def apply( indices_type=self.topk_indices_dtype, ) - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp8 + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ab_strides1=self.ab_strides1, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.c_strides2, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + else: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): From 0f9914ea8a2765424c76f31778d0d4d22a7c0f5b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 11 Jul 2025 05:00:45 +0000 Subject: [PATCH 2/5] Lint Signed-off-by: ElizaWszola --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5a141b34273..08eb15efc4c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -738,7 +738,7 @@ def __init__( "channelwise, dynamic per token quantization.") self.topk_indices_dtype = None - self.fused_experts = None + self.fused_experts = None # type: ignore self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -986,7 +986,6 @@ def apply( layer.w2_weight, topk_weights, topk_ids, - per_act_token=per_act_token, activation=activation, global_num_experts=global_num_experts, expert_map=None if self.disable_expert_map else expert_map, From ff1424eed45a552a85f923b0eab4efd8895d5fef Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 11 Jul 2025 05:39:41 +0000 Subject: [PATCH 3/5] Make CompressedTensorsW8A8Fp8MoECutlassMethod aware of indices type Signed-off-by: ElizaWszola --- vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py | 2 +- .../quantization/compressed_tensors/compressed_tensors_moe.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index c84f28d0874..eb99d8ac6b7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -78,7 +78,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 08eb15efc4c..8090cd5b26e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -949,7 +949,8 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) a1_scale = layer.w13_input_scale a2_scale = layer.w2_input_scale From 07793d48277fe06dddfcedeb2bef26d85b49c779 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 11 Jul 2025 15:47:36 +0000 Subject: [PATCH 4/5] PPLX type fix Signed-off-by: ElizaWszola --- .../layers/fused_moe/pplx_prepare_finalize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index ba331610eef..9fb63569d37 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -83,7 +83,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.uint32 + return torch.int32 def num_dispatchers(self) -> int: return self.num_dispatchers_ @@ -204,7 +204,7 @@ def prepare( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), bound_m=bound_m, ) @@ -249,7 +249,7 @@ def finalize( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) From 9f160333a2eafe72ef93824a40e9331477140819 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 15 Jul 2025 16:15:10 +0000 Subject: [PATCH 5/5] Make ab_strides1 and c_strides2 one tensor Signed-off-by: ElizaWszola --- .../compressed_tensors_moe.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8090cd5b26e..9943346585e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -860,10 +860,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False) device = layer.w13_weight.device - self.ab_strides1 = torch.full((layer.local_num_experts, ), - layer.hidden_size, - device=device, - dtype=torch.int64) + # ab_strides1 and c_strides2 are the same + self.ab_strides1_c_strides2 = torch.full((layer.local_num_experts, ), + layer.hidden_size, + device=device, + dtype=torch.int64) self.ab_strides2 = torch.full((layer.local_num_experts, ), layer.intermediate_size_per_partition, device=device, @@ -872,10 +873,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: 2 * layer.intermediate_size_per_partition, device=device, dtype=torch.int64) - self.c_strides2 = torch.full((layer.local_num_experts, ), - layer.hidden_size, - device=device, - dtype=torch.int64) def select_gemm_impl( self, @@ -899,10 +896,10 @@ def select_gemm_impl( moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, self.weight_quant.strategy == QuantizationStrategy.CHANNEL, - ab_strides1=self.ab_strides1, + ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, - c_strides2=self.c_strides2, + c_strides2=self.ab_strides1_c_strides2, num_dispatchers=num_dispatchers, use_batched_format=use_batched_format, ) @@ -973,10 +970,10 @@ def apply( expert_map=None if self.disable_expert_map else expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - ab_strides1=self.ab_strides1, + ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, - c_strides2=self.c_strides2, + c_strides2=self.ab_strides1_c_strides2, a1_scale=a1_scale, a2_scale=a2_scale, )