diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 4ca8e6b97fce..4e473bda1fe2 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -583,8 +583,8 @@ def forward_cuda(
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
             if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(
-                    x, attn_metadata.query_start_loc, mamba2_metadata)
+                mamba2_metadata = update_metadata(x, query_start_loc_p,
+                                                  mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -593,6 +593,7 @@ def forward_cuda(
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                metadata=mamba2_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]

@@ -603,9 +604,14 @@ def forward_cuda(
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
                 # making a copy of the states
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                if envs.VLLM_USE_V1:
+                    initial_states = torch.where(
+                        has_initial_states_p[:, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)
+                else:
+                    initial_states = torch.where(
+                        has_initial_states_p[:num_prefills, None, None, None],
+                        ssm_state[state_indices_tensor_p], 0)

             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index a8bd0067bf45..b8d4bbc37105 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -55,7 +55,6 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     USE_PAD_SLOT: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
-    DECODE_SEQLEN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -416,7 +415,7 @@ def causal_conv1d_fn(
         activation = "silu"

     args = None
-    out = torch.zeros_like(x)
+    out = torch.empty_like(x)
     if metadata is not None:
         cu_seqlen = metadata.cu_seqlen
         nums_dict = metadata.nums_dict
@@ -607,7 +606,6 @@ def grid(META):
         IS_CONTINUOUS_BATCHING=cache_indices is not None,
         USE_PAD_SLOT=pad_slot_id is not None,
         NP2_STATELEN=np2_statelen,
-        DECODE_SEQLEN=1,
         #launch_cooperative_grid=True
         BLOCK_M=8,
         BLOCK_N=256,
@@ -665,7 +663,8 @@ def _causal_conv1d_update_kernel(

     if IS_CONTINUOUS_BATCHING:
         # mask = idx_seq < batch
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
+            tl.int64)
     else:
         conv_state_batch_coord = idx_seq
     if USE_PAD_SLOT:  # noqa