[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218


Merged

merged 48 commits on Jul 9, 2025
Changes from 10 commits

Commits (48)
ad83738
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
f4c56bf
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
dfa7159
resolve merge conflict
tmhoangt May 15, 2025
61d7ed9
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
8882cef
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
775e561
refactor for code style
tmhoangt May 16, 2025
939a823
refactor for code style
tmhoangt May 16, 2025
29b7941
refactor for code style
tmhoangt May 16, 2025
7bfe0e8
refactor for code style
tmhoangt May 16, 2025
52d601c
refactor for code style
tmhoangt May 16, 2025
081a8be
Update tests/kernels/mamba/test_causal_conv1d.py
thoangtrvn Jun 2, 2025
9eb1cc3
update test code to cover more use-cases
tmhoangt Jun 2, 2025
091b31e
refactor code based on feedback
tmhoangt Jun 4, 2025
bfabaae
refactor code based on feedback
tmhoangt Jun 4, 2025
da660f0
refactor code based on feedback
tmhoangt Jun 4, 2025
7af7f58
refactor code based on feedback
tmhoangt Jun 4, 2025
10e332c
Merge branch 'main' into pr_conv1d_triton
thoangtrvn Jun 4, 2025
ecb3a2c
refactor code based on feedback
tmhoangt Jun 4, 2025
107911a
refactor code based on feedback
tmhoangt Jun 4, 2025
bfc2f28
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
ef21b3d
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
400e669
Merge branch 'pr_conv1d_triton' of github.com:thoangtrvn/vllm into pr…
tmhoangt Jun 4, 2025
f0be762
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
4cfb12d
revert code change based on feedback
tmhoangt Jun 5, 2025
64ee33d
revert code change based on feedback
tmhoangt Jun 5, 2025
19586c5
revert code change based on feedback
tmhoangt Jun 5, 2025
e3192e8
migrate code change based on feedback
tmhoangt Jun 5, 2025
8aad208
migrate code change based on feedback
tmhoangt Jun 5, 2025
a0d2170
revert code change based on feedback
tmhoangt Jun 5, 2025
4d1bb63
revert code change based on feedback
tmhoangt Jun 5, 2025
679eb1c
migrate code change based on feedback
tmhoangt Jun 5, 2025
c782f25
fix merge conflict from upstream/main
tmhoangt Jun 5, 2025
6d0e77a
reduce kernel test time
tmhoangt Jun 10, 2025
20a34c5
remove CUDA causal-conv1d kernel
tmhoangt Jun 10, 2025
82091a7
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 10, 2025
6784173
remove unused code based on feedback
tmhoangt Jun 10, 2025
6e8d966
update argument name
tmhoangt Jun 11, 2025
089b10b
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 26, 2025
761bdea
Use typing.Union to work with Python 3.9
tmhoangt Jun 26, 2025
7448f0d
move _query_start_loc_to_chunk_indices_offsets to mamba_attn.py to avoid
tmhoangt Jun 26, 2025
5e41d6b
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 28, 2025
bbef3ac
Update vllm/v1/attention/backends/mamba_attn.py
thoangtrvn Jun 30, 2025
6527b9d
revert space change in zamba2.py and address comments
tmhoangt Jun 30, 2025
129b32d
revert to using `has_initial_state` argument for causal_conv1d_fn, fix
tmhoangt Jul 8, 2025
37f801a
revert to using `has_initial_state` argument for causal_conv1d_fn
tmhoangt Jul 8, 2025
a208d04
update code to work in v1
tmhoangt Jul 8, 2025
a798b14
make typing compatible Python 3.9
tmhoangt Jul 8, 2025
736eeba
Merge branch 'main' into pr_conv1d_triton
tlrmchlsmth Jul 9, 2025
238 changes: 237 additions & 1 deletion tests/kernels/mamba/test_causal_conv1d.py
@@ -5,12 +5,14 @@
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn, causal_conv1d_fn_triton, causal_conv1d_update,
causal_conv1d_update_triton)
from vllm.platforms import current_platform


@@ -435,3 +437,237 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Collaborator

Should we test more than just seqlen 1?

Contributor Author

ah yes, I can add more to the test code

Contributor Author

done

@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
Collaborator

Could you reduce the amount of time spent in this test? It's taking 102s, and we should try to keep unit test time under control.

Collaborator

Also could you remove the _vllm suffix?

dim, width, seqlen,
has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2

# set seed
current_platform.seed_everything(0)

padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size

channel_last = True
if not channel_last:
x = torch.randn(padded_batch_size,
dim,
seqlen,
device=device,
dtype=itype)
else:
Collaborator

Remove this?

# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size,
seqlen,
dim,
device=device,
dtype=itype).transpose(1, 2)

x_ref = x.clone()

conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)

if not channel_last:
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
else:
Collaborator

ditto

# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)

conv_state_for_padding_test = conv_state.clone()

weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"

out = causal_conv1d_update_triton(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)

assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen', [8, 16, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
@pytest.mark.parametrize('with_padding', [True, False])
@pytest.mark.parametrize('batch', [4])
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
Collaborator

Could you rename this to test_causal_conv1d_varlen? We don't need to clarify that it's for vllm, since this is in the vllm codebase

device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
seqlens = []
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1

eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])

total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
channel_last = True
if not channel_last:
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
Collaborator

ditto, remove

x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device,
dtype=itype), "b s d -> b d s")[:, 4096:4096 + dim, :]

weight = torch.randn(dim, width, device=device, dtype=itype)

bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
if not channel_last:
final_states = torch.randn(total_entries,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
else:
Collaborator

ditto

final_states = torch.randn(total_entries,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=x.device)[:batch_size]
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn_triton(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_states=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)

out_ref = []
out_ref_b = []

splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[
padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].
unsqueeze(0) if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)

try:
assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)
print("Passed conv_state")
except Exception as e:
print("FAILED conv_state")
raise e
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
try:
assert torch.allclose(unpadded_out,
out_ref_tensor,
rtol=rtol,
atol=atol)
except Exception as e:
input(
"Passed conv_state, but failed output: Press Enter to continue...")

nz = out_ref_tensor.squeeze(0) - unpadded_out
non_zero_indices = torch.nonzero(nz)
print('nonzero indices :', non_zero_indices)
raise e
50 changes: 50 additions & 0 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
@@ -22,6 +24,31 @@ class Mamba2Metadata:
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor

num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
Collaborator

What is nums_dict? This should be documented.

Contributor Author

In a batch of requests, a prefill request is processed in parallel, with each Triton program handling BLOCK_M tokens. The values in `batch_ptr` and `token_chunk_offset_ptr` depend on the choice of BLOCK_M, and the best BLOCK_M can differ across inputs and hardware. Currently BLOCK_M is fixed at 8 for all inputs, which is a good choice to avoid the overhead of Triton autotune.

nums_dict[BLOCK_M] = {batch_ptr, token_chunk_offset_ptr}

I added the documentation accordingly.
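
To make the mapping concrete, here is a minimal sketch (illustrative only; `build_program_metadata` and the exact layout are assumptions, not the PR's helper) of how `batch_ptr` and `token_chunk_offset_ptr` could be derived from per-request sequence lengths for a given `BLOCK_M`:

```python
# Hypothetical sketch: derive per-program launch metadata from sequence lengths.
import math

def build_program_metadata(seqlens: list[int], BLOCK_M: int = 8):
    batch_ptr = []               # which request each Triton program works on
    token_chunk_offset_ptr = []  # which BLOCK_M-sized chunk within that request
    for req_idx, seqlen in enumerate(seqlens):
        n_chunks = math.ceil(seqlen / BLOCK_M)
        batch_ptr.extend([req_idx] * n_chunks)
        token_chunk_offset_ptr.extend(range(n_chunks))
    return batch_ptr, token_chunk_offset_ptr

# Example: two prefill requests of length 20 and 5 with BLOCK_M=8 need 3 + 1
# programs; nums_dict would cache one such pair per BLOCK_M value.
batch_ptr, token_chunk_offset_ptr = build_program_metadata([20, 5], BLOCK_M=8)
nums_dict = {8: {"batch_ptr": batch_ptr,
                 "token_chunk_offset_ptr": token_chunk_offset_ptr}}
print(batch_ptr)               # [0, 0, 0, 1]
print(token_chunk_offset_ptr)  # [0, 1, 2, 0]
```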

is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
Collaborator

Could you explain what this is?

Contributor Author

The Triton causal-conv1d kernel (prefill or mixed prefill/decode) processes input x using conv_state. While the conv_state update must be in-place, the output is written to the out tensor rather than back into x to avoid a race condition, since each Triton program handles one segment of a request (unlike the CUDA kernel, where one thread block handles one full request). out is reused across all layers.
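
To illustrate the hazard, here is a minimal PyTorch sketch (not the Triton kernel itself; the chunking and names are assumptions) of a causal conv computed chunk by chunk: each chunk's causal window reads the tail of the previous chunk from `x`, so if the loop bodies ran concurrently (as Triton programs do) and wrote back into `x`, that read could see overwritten values; writing to a separate `out` buffer avoids this.

```python
import torch
import torch.nn.functional as F

seqlen, dim, width, chunk = 32, 4, 4, 8
x = torch.randn(dim, seqlen)
w = torch.randn(dim, width)

out = torch.empty_like(x)  # separate output buffer, as described above
for start in range(0, seqlen, chunk):  # each loop body stands in for one program
    end = min(start + chunk, seqlen)
    # causal lookback: this chunk also reads the previous (width - 1) tokens of x
    lo = max(start - (width - 1), 0)
    window = F.pad(x[:, lo:end], (width - 1 - (start - lo), 0))
    # depthwise causal conv over the window, keeping only this chunk's outputs
    out[:, start:end] = F.conv1d(window.unsqueeze(0), w.unsqueeze(1),
                                 groups=dim).squeeze(0)

# reference: one depthwise causal conv over the whole sequence
ref = F.conv1d(F.pad(x, (width - 1, 0)).unsqueeze(0), w.unsqueeze(1),
               groups=dim).squeeze(0)
assert torch.allclose(out, ref, atol=1e-5)
```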

Collaborator

Since out isn't metadata, please remove it from Mamba2Metadata and treat it like a normal tensor

stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
Collaborator

There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?

Contributor Author

The information here is reused across Mamba layers; even a stride() query triggers Torch calls, which adds unnecessary overhead. I can add a description as needed. Please let me know what you think @tlrmchlsmth
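
For context, a minimal sketch of the caching pattern being described (illustrative only; `ConvLaunchInfo` and `conv_forward` are hypothetical names, and this pattern was later reverted per the review below): host-side values such as strides are filled in once by the first layer of a forward pass and reused by the later layers.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ConvLaunchInfo:
    stride_x_dim: Optional[int] = None
    stride_x_token: Optional[int] = None
    dim: Optional[int] = None


def conv_forward(x: torch.Tensor, info: ConvLaunchInfo) -> None:
    if info.dim is None:  # first layer of the forward pass fills the cache
        info.stride_x_dim, info.stride_x_token = x.stride(0), x.stride(1)
        info.dim = x.shape[0]
    # ... later layers reuse info.* directly when launching the kernel ...


info = ConvLaunchInfo()
for _ in range(4):  # one forward pass over four Mamba layers
    conv_forward(torch.empty(64, 128), info)
```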

Contributor Author

Done, updated the code.

Collaborator

I think most of this should be removed. For CPU overheads in decode we can rely on CUDA graphs, and for prefill they are amortized.

Contributor Author

I'm running some tests to revert this.

DONE

Contributor Author

@tlrmchlsmth: done with the update based on your feedback.



def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
@@ -62,7 +89,9 @@ def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
def prepare_mamba2_metadata(
chunk_size: int,
attn_metadata: AttentionMetadata,
mamba2_metadata=None,
) -> Mamba2Metadata:
# ruff: noqa: E501

# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
@@ -78,6 +107,12 @@

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
Collaborator

We can rely on batch reordering and require that it be used for this implementation.

Contributor Author (@thoangtrvn, Jun 2, 2025)

The change from a previous PR (cuda-split) implies this assumption; this PR doesn't have this assumption, which makes it more suitable for the vLLM v1 design. The comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth
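
As a minimal sketch of the assumption under discussion (illustrative values, not vLLM code): if batch reordering guarantees prefill requests precede decode requests, the per-prefill flag reduces to a `:num_prefills` slice instead of masking and gathering over the whole batch as in the commented-out code above.

```python
import torch

num_prefills = 2
# context lengths for [prefill with prior chunk, fresh prefill, decode, decode, decode]
context_lens_tensor = torch.tensor([7, 0, 3, 5, 2])

# With the "prefills come first" ordering, a slice is enough:
has_initial_states = context_lens_tensor[:num_prefills] > 0
print(has_initial_states)  # tensor([ True, False])
```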

Contributor Author

removed comment

Collaborator

Yes, please remove the comment.

We can rely on batch reordering even in vLLM V1, so this is a non-issue.

Contributor Author

done

if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
@@ -103,6 +138,21 @@ def prepare_mamba2_metadata(
_query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens)

if mamba2_metadata is not None:
mamba2_metadata.has_initial_states = has_initial_states
mamba2_metadata.prep_initial_states = prep_initial_states
mamba2_metadata.chunk_size = chunk_size
mamba2_metadata.seq_idx = seq_idx
mamba2_metadata.chunk_indices = chunk_indices
mamba2_metadata.chunk_offsets = chunk_offsets
# We use 2 reset flags:
# * mamba2_metadata.width is None # update config at first run (never change whole session for a given model)
# (become available at first layer, e.g. conv_weights)
# * mamba2_metadata.cu_seqlen is None # update config specific to (each input)
# (become available at first layer, e.g. conv_weights)
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input

return mamba2_metadata
return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,