[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE #20762

Open · wants to merge 6 commits into base: main

35 changes: 34 additions & 1 deletion benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
53 changes: 42 additions & 11 deletions csrc/moe/moe_permute_unpermute_op.cu
@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}

template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
const int32_t* dst2src_map,
T* output, int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];

if (blockIdx.x < num_dst_rows) {
// Duplicate and permute rows
auto const* source_row_ptr = input + source_row_idx * num_cols;
auto* dest_row_ptr = output + dest_row_idx * num_cols;

int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;

for (int elem_index = start_offset; elem_index < num_cols;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}

void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
@@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);

TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
Review comment from a Contributor on lines +200 to +217 (medium):

The MOE_DISPATCH macro is duplicated in both the if and else branches. It can be refactored into a single MOE_DISPATCH call with the conditional logic inside the lambda, reducing duplication and improving readability and maintainability.

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
    // use slow kernel if num_cols can't be aligned to 128 bits
    shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  } else {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  }
});

}

#else
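For reference, a minimal PyTorch sketch (ours, not part of the diff) of what shuffle_rows computes, assuming the same dst2src_map semantics as the kernels above: output row i is a copy of input row dst2src_map[i]. The fast kernel uses 128-bit vectorized copies; the new slow kernel handles row widths that fail the alignment check with plain element-wise copies.

import torch

def shuffle_rows_reference(x: torch.Tensor,
                           dst2src_map: torch.Tensor) -> torch.Tensor:
    # Pure-PyTorch reference: output[i] = x[dst2src_map[i]].
    # index_select accepts int32 indices and duplicates rows as needed,
    # matching the "duplicate and permute rows" behaviour of the kernels.
    return x.index_select(0, dst2src_map)

# Example: expand 3 source rows into 4 destination rows.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
dst2src_map = torch.tensor([2, 0, 0, 1], dtype=torch.int32)
print(shuffle_rows_reference(x, dst2src_map))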
14 changes: 12 additions & 2 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
Expand All @@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False)

workspace13.random_()
output_random_workspace = torch.empty(output_shape,
22 changes: 22 additions & 0 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank

num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)

ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)

experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)

62 changes: 39 additions & 23 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -13,8 +13,7 @@
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types

@@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)

a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
a1q = ops.shuffle_rows(a1q, a_map)
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
if per_act_token else a1q_scale)
expert_offsets = expert_offsets[:-1]

ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)

if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
non_blocking=True)


# TODO (bnell): split class batched vs. non-batched?
@@ -222,6 +210,10 @@ def __init__(
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
@@ -238,6 +230,10 @@ def __init__(
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.use_batched_format = use_batched_format

@property
@@ -324,7 +320,8 @@ def apply(
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
@@ -338,6 +335,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@@ -365,6 +366,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -397,6 +409,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
use_batched_format=False,
),
)
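Because the stride tensors are now part of the public cutlass_moe_fp8 / CutlassExpertsFp8 interface instead of being allocated inside run_cutlass_moe_fp8 on every call, callers build them once and reuse them. A minimal sketch of that setup, assuming the same dimension convention as the tests and benchmark above (k = hidden size, n = intermediate size); the helper name is ours:

import torch

def build_cutlass_moe_strides(num_experts: int, n: int, k: int,
                              device: torch.device) -> dict[str, torch.Tensor]:
    # Per-expert leading dimensions for the two grouped GEMMs, matching the
    # values used in the benchmark and tests in this PR.
    def full(value: int) -> torch.Tensor:
        return torch.full((num_experts, ), value, device=device,
                          dtype=torch.int64)

    return {
        "ab_strides1": full(k),     # first GEMM reads rows of width K
        "ab_strides2": full(n),     # second GEMM reads rows of width N
        "c_strides1": full(2 * n),  # first GEMM writes rows of width 2N
        "c_strides2": full(k),      # second GEMM writes rows of width K
    }

# Built once (e.g. at weight-loading time) and reused on every forward pass,
# roughly:
# strides = build_cutlass_moe_strides(num_experts, n, k, torch.device("cuda"))
# cutlass_moe_fp8(..., w1_scale=w1_scale, w2_scale=w2_scale, **strides,
#                 per_act_token=per_act_token)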
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -204,7 +204,7 @@ def prepare(
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
indices=topk_ids.view(dtype=torch.uint32),
bound_m=bound_m,
)

@@ -249,7 +249,7 @@ def finalize(
topk_weights = torch.ones_like(topk_weights)

self.a2a.combine(out_tokens=output,
indices=topk_ids,
indices=topk_ids.view(dtype=torch.uint32),
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
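The indices change in prepare/finalize reinterprets topk_ids rather than converting it: Tensor.view(dtype=...) returns a view over the same storage with no copy, so the int32 ids can be handed to the pplx all-to-all, which appears to expect uint32 indices. A quick sketch (assuming a PyTorch build that exposes torch.uint32 and non-negative ids):

import torch

topk_ids = torch.tensor([[0, 3], [2, 1]], dtype=torch.int32)
as_uint32 = topk_ids.view(dtype=torch.uint32)

# Same storage, same bits; only the dtype tag changes, so no extra copy
# is made on the hot path.
assert as_uint32.data_ptr() == topk_ids.data_ptr()
assert as_uint32.dtype == torch.uint32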