From 1749de8d36fd5cc725e8d5b1a498ca87882cffdd Mon Sep 17 00:00:00 2001
From: "Tuan M. Hoang-Trong"
Date: Fri, 11 Jul 2025 15:28:17 -0400
Subject: [PATCH 1/3] fix two issues: using metadata for causal-conv1d and
 init_states in v0 vLLM

Signed-off-by: Tuan M. Hoang-Trong
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 4ca8e6b97fce..4e473bda1fe2 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -583,8 +583,8 @@ def forward_cuda(
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
             if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(
-                    x, attn_metadata.query_start_loc, mamba2_metadata)
+                mamba2_metadata = update_metadata(x, query_start_loc_p,
+                                                  mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -593,6 +593,7 @@ def forward_cuda(
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                metadata=mamba2_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
 
@@ -603,9 +604,14 @@ def forward_cuda(
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
                 # making a copy of the states
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                if envs.VLLM_USE_V1:
+                    initial_states = torch.where(
+                        has_initial_states_p[:, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
+                else:
+                    initial_states = torch.where(
+                        has_initial_states_p[:num_prefills, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,

From 0a2068e91f0d8dd4992064195e4bf696c0b04eab Mon Sep 17 00:00:00 2001
From: "Tuan M. Hoang-Trong"
Date: Mon, 14 Jul 2025 16:25:06 -0400
Subject: [PATCH 2/3] use torch.empty_like rather than torch.zeros_like to
 save some overhead

Signed-off-by: Tuan M. Hoang-Trong
---
 vllm/model_executor/layers/mamba/ops/causal_conv1d.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 6793f6def2b7..fe59c302c3f9 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     USE_PAD_SLOT: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
-    DECODE_SEQLEN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -415,7 +414,7 @@ def causal_conv1d_fn(
         activation = "silu"
 
     args = None
-    out = torch.zeros_like(x)
+    out = torch.empty_like(x)
     if metadata is not None:
         cu_seqlen = metadata.cu_seqlen
         nums_dict = metadata.nums_dict
@@ -606,7 +605,6 @@ def grid(META):
         IS_CONTINUOUS_BATCHING=cache_indices is not None,
         USE_PAD_SLOT=pad_slot_id is not None,
         NP2_STATELEN=np2_statelen,
-        DECODE_SEQLEN=1,
         #launch_cooperative_grid=True
         BLOCK_M=8,
         BLOCK_N=256,
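Note on PATCH 1/3: the v0/v1 split exists because in v0 `has_initial_states_p`
covers every sequence in the batch (prefills followed by decodes), while
`state_indices_tensor_p` indexes the prefill sequences only, so the mask must
be sliced to `[:num_prefills]` before broadcasting against the gathered
states. A minimal stand-alone sketch of the broadcast the slice makes work
(all shapes and values below are hypothetical, for illustration only):

    import torch

    num_prefills = 3
    # v0-style mask over the whole batch: 3 prefills followed by 5 decodes
    has_initial_states = torch.tensor(
        [True, False, True, True, True, True, True, True])
    ssm_state = torch.randn(16, 2, 4, 4)  # (cache_slots, heads, dim, dstate)
    state_indices_tensor_p = torch.tensor([7, 2, 11])  # one slot per prefill

    # Without the slice, torch.where would broadcast an (8, 1, 1, 1) mask
    # against a (3, 2, 4, 4) tensor and fail; the slice keeps them in sync.
    initial_states = torch.where(
        has_initial_states[:num_prefills, None, None, None],
        ssm_state[state_indices_tensor_p], 0)
    print(initial_states.shape)  # torch.Size([3, 2, 4, 4])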
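Note on PATCH 2/3: `torch.zeros_like` allocates and then launches a memset,
while `torch.empty_like` only allocates; the memset is pure overhead here
because the kernel writes every element of `out` before anything reads it.
A minimal sketch of the trade-off (`kernel_stand_in` is a hypothetical
placeholder for the Triton kernel, not part of the patch):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)

    def kernel_stand_in(out: torch.Tensor) -> torch.Tensor:
        # stands in for a kernel that overwrites every element of `out`
        out.copy_(x)
        return out

    # zeros_like pays for a memset whose result is immediately overwritten;
    # empty_like skips it, so the result is identical but cheaper to produce.
    out_zeros = kernel_stand_in(torch.zeros_like(x))
    out_empty = kernel_stand_in(torch.empty_like(x))
    assert torch.equal(out_zeros, out_empty)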
From cc2c6bd13d003a774fc3a947dfbbd429a24520fb Mon Sep 17 00:00:00 2001
From: "Tuan M. Hoang-Trong"
Date: Mon, 14 Jul 2025 21:18:41 -0400
Subject: [PATCH 3/3] address potential indexing overflow in the
 causal-conv1d-update kernel

Signed-off-by: Tuan M. Hoang-Trong
---
 vllm/model_executor/layers/mamba/ops/causal_conv1d.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 0f2843c6a8f7..b8d4bbc37105 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -663,7 +663,8 @@ def _causal_conv1d_update_kernel(
     if IS_CONTINUOUS_BATCHING:
         # mask = idx_seq < batch
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
+            tl.int64)
     else:
         conv_state_batch_coord = idx_seq
     if USE_PAD_SLOT:  # noqa
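Note on PATCH 3/3: `tl.load` from an int32 index tensor yields an int32
value, and the loaded batch coordinate is later multiplied by the conv-state
strides to form a pointer offset; with a large conv-state cache that product
can exceed 2**31 - 1 and wrap around if kept in 32 bits, so the
`.to(tl.int64)` cast forces 64-bit offset arithmetic. A small stand-alone
sketch of the wraparound (the index and stride values are hypothetical):

    import torch

    # Index and per-slot stride each fit in int32; their product does not.
    idx = torch.tensor(40_000, dtype=torch.int32)      # cache-slot index
    stride = torch.tensor(100_000, dtype=torch.int32)  # elements per slot

    print(idx * stride)                  # wraps: tensor(-294967296, dtype=torch.int32)
    print(idx.to(torch.int64) * stride)  # correct: tensor(4000000000)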