Skip to content

Commit f2cd532

Browse files
thoangtrvn and tmhoangt
authored and committed
[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 (vllm-project#20838)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Signed-off-by: Himanshu Jaju <hj@mistral.ai>
1 parent 25ae593 commit f2cd532

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,8 @@ def forward_cuda(
573573
x = hidden_states_B_C_p.transpose(
574574
0, 1) # this is the form that causal-conv see
575575
if mamba2_metadata.cu_seqlen is None:
576-
mamba2_metadata = update_metadata(
577-
x, attn_metadata.query_start_loc, mamba2_metadata)
576+
mamba2_metadata = update_metadata(x, query_start_loc_p,
577+
mamba2_metadata)
578578
hidden_states_B_C_p = causal_conv1d_fn(
579579
x,
580580
conv_weights,
@@ -583,6 +583,7 @@ def forward_cuda(
583583
conv_states=conv_state,
584584
has_initial_state=has_initial_states_p,
585585
cache_indices=state_indices_tensor_p,
586+
metadata=mamba2_metadata,
586587
query_start_loc=query_start_loc_p).transpose(
587588
0, 1)[:num_prefill_tokens]
588589

@@ -593,9 +594,14 @@ def forward_cuda(
593594
initial_states = None
594595
if (has_initial_states_p is not None and prep_initial_states):
595596
# making a copy of the states
596-
initial_states = torch.where(
597-
has_initial_states_p[:, None, None, None],
598-
ssm_state[state_indices_tensor_p], 0)
597+
if envs.VLLM_USE_V1:
598+
initial_states = torch.where(
599+
has_initial_states_p[:, None, None, None],
600+
ssm_state[state_indices_tensor_p], 0)
601+
else:
602+
initial_states = torch.where(
603+
has_initial_states_p[:num_prefills, None, None, None],
604+
ssm_state[state_indices_tensor_p], 0)
599605

600606
scan_output, varlen_state = mamba_chunk_scan_combined(
601607
hidden_states_p.view(1, num_prefill_tokens,

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
5555
IS_CONTINUOUS_BATCHING: tl.constexpr,
5656
USE_PAD_SLOT: tl.constexpr,
5757
NP2_STATELEN: tl.constexpr,
58-
DECODE_SEQLEN: tl.constexpr,
5958
BLOCK_M: tl.constexpr,
6059
BLOCK_N: tl.constexpr,
6160
):
@@ -416,7 +415,7 @@ def causal_conv1d_fn(
416415
activation = "silu"
417416

418417
args = None
419-
out = torch.zeros_like(x)
418+
out = torch.empty_like(x)
420419
if metadata is not None:
421420
cu_seqlen = metadata.cu_seqlen
422421
nums_dict = metadata.nums_dict
@@ -607,7 +606,6 @@ def grid(META):
607606
IS_CONTINUOUS_BATCHING=cache_indices is not None,
608607
USE_PAD_SLOT=pad_slot_id is not None,
609608
NP2_STATELEN=np2_statelen,
610-
DECODE_SEQLEN=1,
611609
#launch_cooperative_grid=True
612610
BLOCK_M=8,
613611
BLOCK_N=256,
@@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(
665663

666664
if IS_CONTINUOUS_BATCHING:
667665
# mask = idx_seq < batch
668-
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
666+
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
667+
tl.int64)
669668
else:
670669
conv_state_batch_coord = idx_seq
671670
if USE_PAD_SLOT: # noqa

0 commit comments

Comments (0)