Commit 5cf2dae

varun-sundar-rabindranath and Varun authored
[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298)
Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun <vsundarr@redhat.com>
1 parent b808919 commit 5cf2dae

File tree: 8 files changed, +98 -36 lines changed

tests/kernels/moe/test_pplx_moe.py
Lines changed: 1 addition & 1 deletion

@@ -274,7 +274,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

-    b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
+    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
         a_chunk,
         None,
         None,

vllm/distributed/device_communicators/all2all.py
Lines changed: 5 additions & 10 deletions

@@ -233,16 +233,11 @@ def _make_all2all_kwargs(
         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = 1024 * 1024 * 1024
         num_qps_per_rank = num_local_experts
-        num_rdma_bytes = None
-
-        if self.internode:
-            num_rdma_bytes = 1024 * 1024 * 1024
-        else:
-            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
-                num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
-                hidden=token_hidden_size,
-                num_ranks=num_ep_ranks,
-                num_experts=num_global_experts)
+        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
+            hidden=token_hidden_size,
+            num_ranks=num_ep_ranks,
+            num_experts=num_global_experts)

         assert num_rdma_bytes is not None
         return dict(group=self.cpu_group,
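
For reference, the sizing logic after this change can be exercised standalone. A minimal sketch, assuming deep_ep is installed; the rank/expert counts below are illustrative stand-ins for values that really come from vLLM's parallel and model config:

import deep_ep

# Illustrative values; real values come from the vLLM configuration.
max_num_tokens_per_dp_rank = 256
token_hidden_size = 7168
num_ep_ranks = 16
num_global_experts = 256
num_local_experts = num_global_experts // num_ep_ranks

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_qps_per_rank = num_local_experts

# The RDMA buffer is now always sized from DeepEP's low-latency hint,
# regardless of whether the deployment is internode or intranode.
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
    num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
    hidden=token_hidden_size,
    num_ranks=num_ep_ranks,
    num_experts=num_global_experts)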

vllm/envs.py
Lines changed: 5 additions & 0 deletions

@@ -110,6 +110,7 @@
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@@ -761,6 +762,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_DP_MASTER_PORT":
     lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),

+    # Randomize inputs during dummy runs when using Data Parallel
+    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
+    lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
+
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
Lines changed: 18 additions & 11 deletions

@@ -80,11 +80,13 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> tuple[int, int, torch.dtype]:
+
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
         workspace1 = M_sum * max(N * 2, K)
-        workspace2 = M_sum * N
+        workspace2 = M_sum * max(N, K)
+
         return (workspace1, workspace2, a.dtype)

     def apply(
@@ -135,26 +137,31 @@ def apply(

         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
-        workspace1 = _resize_cache(workspace13, (M_sum, N))
-        workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
-        workspace3 = _resize_cache(workspace13, (M_sum, K))
+
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M_sum, N // 2))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        out = _resize_cache(workspace13, (inv_perm.size(0), K))

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
+            (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

-        self.activation(activation, workspace2, workspace1.view(-1, N))
+        self.activation(activation, act_out, mm1_out.view(-1, N))

         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                                    self.block_shape[1],
-                                                   column_major_scales=True)
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
+            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-        workspace3 = workspace3[inv_perm, ...]
+        torch.index_select(mm2_out, 0, inv_perm, out=out)

-        return workspace3
+        return out


 def deep_gemm_moe_fp8(
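
The rewritten apply() aliases the two persistent workspaces for every intermediate and un-permutes with an out= gather instead of allocating fresh tensors. A standalone sketch of that pattern in plain PyTorch; resize_cache here is an illustrative stand-in for vLLM's _resize_cache and the shapes are made up:

import torch

def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Illustrative stand-in: reinterpret the first prod(shape) elements of a
    # flat workspace buffer as a tensor of the requested shape (no copy).
    numel = 1
    for s in shape:
        numel *= s
    return buf.flatten()[:numel].view(shape)

M_sum, N, K = 128, 256, 512
workspace13 = torch.empty(M_sum * max(N * 2, K), dtype=torch.bfloat16)
workspace2 = torch.empty(M_sum * max(N, K), dtype=torch.bfloat16)

mm1_out = resize_cache(workspace13, (M_sum, N))      # first grouped GEMM output
act_out = resize_cache(workspace2, (M_sum, N // 2))  # activation output
mm2_out = resize_cache(workspace2, (M_sum, K))       # second grouped GEMM output

# Un-permutation without a new allocation: gather rows of mm2_out directly
# into another view of workspace13 via the out= argument.
inv_perm = torch.randperm(M_sum)
out = resize_cache(workspace13, (M_sum, K))
torch.index_select(mm2_out, 0, inv_perm, out=out)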

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
Lines changed: 10 additions & 6 deletions

@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)

@@ -193,20 +194,23 @@ def _apply_weights_and_reduce(self, num_tokens: int,
                                   apply_router_weight_on_input: bool,
                                   output_dtype: torch.dtype):

+        hidden_dim = fused_expert_output.size(-1)
         if fused_expert_output.ndim == 2:
-            hidden_dim = fused_expert_output.size(-1)
             fused_expert_output = fused_expert_output.view(
                 num_tokens, -1, hidden_dim)

         if not apply_router_weight_on_input:
             # The DeepEP combine kernels don't do the topk weight
             # multiplication. We multiply the weights locally.
-            fused_expert_output = fused_expert_output.to(torch.float32)
-            fused_expert_output = fused_expert_output * topk_weights.view(
-                fused_expert_output.size(0), -1, 1)
-            fused_expert_output = fused_expert_output.to(output_dtype)
+            m_x_topk = fused_expert_output.size(0)
+            fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))

-        return fused_expert_output.sum(dim=1).to(output_dtype)
+        out = torch.empty((num_tokens, hidden_dim),
+                          device=fused_expert_output.device,
+                          dtype=output_dtype)
+        ops.moe_sum(fused_expert_output, out)
+
+        return out

     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
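
For intuition, the new weight-and-reduce path is equivalent to the following plain-PyTorch version; the real code calls the fused custom op ops.moe_sum, for which torch.sum(..., dim=1, out=...) stands in here, and the shapes/dtypes below are illustrative:

import torch

num_tokens, topk, hidden_dim = 8, 2, 64
fused_expert_output = torch.randn(num_tokens, topk, hidden_dim,
                                  dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, topk, dtype=torch.bfloat16)

# Weight application is done in place: no float32 upcast, no extra copies.
fused_expert_output.mul_(topk_weights.view(num_tokens, -1, 1))

# Reduce over the topk dimension into a preallocated output buffer.
out = torch.empty(num_tokens, hidden_dim, dtype=torch.bfloat16)
torch.sum(fused_expert_output, dim=1, out=out)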

vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ def _moe_permute(
     expert_map: Optional[torch.Tensor],
     block_m: int,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor]]:
+           torch.Tensor]:
     """
     Determine the sorted_token_ids, expert_ids for the given problem size.
     Permute the hidden states and scales according to `sorted_token_ids`.

vllm/model_executor/layers/quantization/utils/fp8_utils.py
Lines changed: 26 additions & 6 deletions

@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
     y_s_ptr += g_id

     cols = tl.arange(0, BLOCK)  # N <= BLOCK
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset

     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
     blocks_per_row = y_num_columns // group_size
     scale_col = g_id % blocks_per_row
     scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
+    # Ensure offset calculation uses int64 for y_s_ptr
+    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
+        tl.int64)
+    y_s_ptr += y_s_ptr_offset

     cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
     mask = cols < group_size
@@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
+    out_q: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -321,6 +335,8 @@ def per_token_group_quant_fp8(
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
             is supported for now.
+        column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
@@ -335,7 +351,11 @@ def per_token_group_quant_fp8(
     fp8_min = finfo.min
     fp8_max = finfo.max

-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    assert out_q is None or out_q.shape == x.shape
+    x_q = out_q
+    if x_q is None:
+        x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
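
The int64 casts guard against offsets such as row * y_row_stride exceeding 2**31 - 1 on large activations. Separately, the new out_q parameter lets callers hand in a preallocated fp8 tensor (for example a workspace view, as deep_gemm_moe.py now does). A usage sketch, assuming a CUDA device with float8_e4m3fn support and illustrative shapes:

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

group_size = 128
x = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")

# Preallocate the quantized output so the kernel writes into it instead of
# allocating a new tensor on every call.
out_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_q, x_s = per_token_group_quant_fp8(x, group_size,
                                     column_major_scales=True,
                                     out_q=out_q)
assert x_q is out_q  # the provided buffer is returned as the quantized tensor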

vllm/v1/worker/gpu_model_runner.py
Lines changed: 32 additions & 1 deletion

@@ -5,13 +5,15 @@
 import gc
 import time
 import weakref
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn

+import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -1727,6 +1729,35 @@ def _get_prompt_logprobs_dict(

         return prompt_logprobs_dict

+    @contextmanager
+    def maybe_randomize_inputs(self, input_ids: torch.Tensor):
+        """
+        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
+        This is to help balance expert-selection
+         - during profile_run
+         - during DP rank dummy run
+        """
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
+        if not randomize_inputs:
+            yield
+        else:
+            import functools
+
+            @functools.cache
+            def rand_input_ids() -> torch.Tensor:
+                return torch.randint_like(
+                    self.input_ids,
+                    low=0,
+                    high=self.model_config.get_vocab_size(),
+                    dtype=input_ids.dtype)
+
+            logger.debug("Randomizing dummy data for DP Rank")
+            input_ids.copy_(rand_input_ids()[:input_ids.size(0)],
+                            non_blocking=True)
+            yield
+            input_ids.fill_(0)
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -1807,7 +1838,7 @@ def _dummy_run(
         intermediate_tensors = self.sync_and_slice_intermediate_tensors(
             num_tokens, None, False)

-        with set_forward_context(
+        with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
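
A condensed, standalone sketch of what maybe_randomize_inputs does around a dummy forward pass; this toy version drops the env-var/DP-size gating and the functools.cache of the random ids, and the buffer size and vocab size are made up:

import torch
from contextlib import contextmanager

@contextmanager
def randomize_dummy_inputs(input_ids: torch.Tensor, vocab_size: int):
    # Fill the persistent dummy buffer with random token ids so DP ranks do
    # not all route to the same experts, then restore zeros afterwards.
    input_ids.copy_(torch.randint_like(input_ids, low=0, high=vocab_size))
    try:
        yield
    finally:
        input_ids.fill_(0)

# Toy usage: randomize only for the duration of the dummy run.
input_ids = torch.zeros(16, dtype=torch.long)
with randomize_dummy_inputs(input_ids, vocab_size=32000):
    pass  # the dummy forward pass would run here
assert int(input_ids.sum()) == 0  # buffer restored after the run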
