[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 #20838

Merged: 4 commits, Jul 15, 2025
16 changes: 11 additions & 5 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -583,8 +583,8 @@ def forward_cuda(
                 x = hidden_states_B_C_p.transpose(
                     0, 1)  # this is the form that causal-conv1d sees
                 if mamba2_metadata.cu_seqlen is None:
-                    mamba2_metadata = update_metadata(
-                        x, attn_metadata.query_start_loc, mamba2_metadata)
+                    mamba2_metadata = update_metadata(x, query_start_loc_p,
+                                                      mamba2_metadata)
                 hidden_states_B_C_p = causal_conv1d_fn(
                     x,
                     conv_weights,
@@ -593,6 +593,7 @@ def forward_cuda(
                     conv_states=conv_state,
                     has_initial_state=has_initial_states_p,
                     cache_indices=state_indices_tensor_p,
+                    metadata=mamba2_metadata,
                     query_start_loc=query_start_loc_p).transpose(
                         0, 1)[:num_prefill_tokens]
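Note on fix (1): update_metadata is now fed the prefill-sliced query_start_loc_p rather than the full-batch attn_metadata.query_start_loc, and the resulting mamba2_metadata is actually passed to causal_conv1d_fn via its metadata argument, so the kernel reuses the precomputed bookkeeping instead of rederiving it. Below is a minimal sketch of why the slice matters, with hypothetical sizes (the exact re-basing arithmetic in vLLM may differ):

    import torch

    # Hypothetical batch: 2 decode requests (1 token each) followed by
    # 2 prefill requests of lengths 3 and 5. Cumulative starts over ALL
    # tokens in the batch:
    query_start_loc = torch.tensor([0, 1, 2, 5, 10])

    # The x handed to causal_conv1d_fn holds only the prefill tokens, so
    # the boundaries must be re-based at the first prefill token:
    num_decodes = 2
    query_start_loc_p = (query_start_loc[num_decodes:] -
                         query_start_loc[num_decodes])
    print(query_start_loc_p)  # tensor([0, 3, 8]) -- valid offsets into x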

@@ -603,9 +604,14 @@ def forward_cuda(
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
                 # making a copy of the states
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                if envs.VLLM_USE_V1:
+                    initial_states = torch.where(
+                        has_initial_states_p[:, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
+                else:
+                    initial_states = torch.where(
+                        has_initial_states_p[:num_prefills, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
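Note on fix (3): in v1, has_initial_states_p already covers only the prefill requests, so it lines up with the states gathered by ssm_state[state_indices_tensor_p]. In v0 it still spans the whole batch, so it must be truncated to [:num_prefills] before the torch.where; otherwise the mask and the gathered states disagree in their leading dimension. A minimal sketch with hypothetical sizes:

    import torch

    num_prefills = 2
    has_initial_states = torch.tensor([True, False, True, False, False])  # whole v0 batch
    ssm_state = torch.randn(8, 4, 2, 2)        # state cache slots
    state_indices_p = torch.tensor([5, 1])     # slots of the 2 prefill requests

    gathered = ssm_state[state_indices_p]      # shape (2, 4, 2, 2)
    # torch.where(has_initial_states[:, None, None, None], gathered, 0.0)
    #   -> RuntimeError: leading dims 5 vs. 2 cannot broadcast
    mask_p = has_initial_states[:num_prefills, None, None, None]  # (2, 1, 1, 1)
    initial_states = torch.where(mask_p, gathered, 0.0)           # OK: (2, 4, 2, 2)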
7 changes: 3 additions & 4 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     USE_PAD_SLOT: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
-    DECODE_SEQLEN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -416,7 +415,7 @@ def causal_conv1d_fn(
         activation = "silu"
 
     args = None
-    out = torch.zeros_like(x)
+    out = torch.empty_like(x)
     if metadata is not None:
         cu_seqlen = metadata.cu_seqlen
         nums_dict = metadata.nums_dict
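Note: out switches from torch.zeros_like(x) to torch.empty_like(x), presumably because the conv kernel overwrites every element it is asked to produce, making the zero fill redundant work. A sketch of the pattern (illustrative only, not the kernel itself):

    import torch

    x = torch.randn(1024, 1024)

    out = torch.empty_like(x)   # allocation only, no fill kernel
    out.copy_(x)                # stand-in for the kernel: every element written

    # torch.zeros_like(x) would additionally launch a memset whose result
    # is immediately overwritten, so empty_like saves that pass.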
@@ -607,7 +606,6 @@ def grid(META):
         IS_CONTINUOUS_BATCHING=cache_indices is not None,
         USE_PAD_SLOT=pad_slot_id is not None,
         NP2_STATELEN=np2_statelen,
-        DECODE_SEQLEN=1,
         #launch_cooperative_grid=True
         BLOCK_M=8,
         BLOCK_N=256,
@@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(
 
     if IS_CONTINUOUS_BATCHING:
         # mask = idx_seq < batch
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
+            tl.int64)
     else:
         conv_state_batch_coord = idx_seq
     if USE_PAD_SLOT:  # noqa
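Note on fix (2): in Triton, tl.load on an int32 index tensor yields int32 values, so the subsequent pointer offset arithmetic (index times the per-slot stride) is carried out in 32 bits and can wrap past 2**31 - 1 once the state cache is large enough; casting the loaded index to tl.int64 promotes the whole offset computation to 64 bits. A plain-Python illustration with made-up sizes:

    import numpy as np

    stride_conv_state_seq = 32768 * 4    # elements per cache slot (hypothetical)
    conv_state_batch_coord = 20000       # slot index chosen for this sequence

    offset32 = np.int32(conv_state_batch_coord) * np.int32(stride_conv_state_seq)
    offset64 = np.int64(conv_state_batch_coord) * np.int64(stride_conv_state_seq)
    print(offset32)  # negative: wrapped around 2**31
    print(offset64)  # 2621440000, the intended element offset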