[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218
Changes to the causal-conv1d kernel tests:
@@ -5,12 +5,14 @@
 import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange

 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
+    causal_conv1d_fn, causal_conv1d_fn_triton, causal_conv1d_update,
+    causal_conv1d_update_triton)
 from vllm.platforms import current_platform
@@ -435,3 +437,237 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case a subset of the sequences is padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
                                                     dim, width, seqlen,
                                                     has_bias, silu_activation,
                                                     itype):

Review comment: Could you reduce the amount of time spent in this test? It's taking 102s, and we should try to keep unit test time under control.
Review comment: Also could you remove the …
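(For context on the runtime: the grid above multiplies out to 3 dtypes x 2 activation settings x 2 bias settings x 3 widths x 3 dims x 2 padding settings = 216 parameter combinations for a single-token update, so dropping, say, torch.float16 and one of the dim values would roughly halve it to 96 cases.)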
device = "cuda" | ||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) | ||
if itype == torch.bfloat16: | ||
rtol, atol = 1e-2, 5e-2 | ||
|
||
# set seed | ||
current_platform.seed_everything(0) | ||
|
||
padding = 5 if with_padding else 0 | ||
padded_batch_size = batch_size + padding | ||
# total_entries = number of cache line | ||
total_entries = 10 * batch_size | ||
|
||
channel_last = True | ||
if not channel_last: | ||
x = torch.randn(padded_batch_size, | ||
dim, | ||
seqlen, | ||
device=device, | ||
dtype=itype) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this? |
||
# x will be (batch, dim, seqlen) with contiguous along dim-axis | ||
x = torch.randn(padded_batch_size, | ||
seqlen, | ||
dim, | ||
device=device, | ||
dtype=itype).transpose(1, 2) | ||
|
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[conv_state_indices] = False
    padded_state_indices = torch.concat([
        conv_state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
    ],
                                        dim=0)

    if not channel_last:
        conv_state = torch.randn(total_entries,
                                 dim,
                                 width - 1,
                                 device=device,
                                 dtype=itype)
    else:
        # conv_state will be (cache_lines, dim, state_len)
        # with contiguous along dim-axis
        conv_state = torch.randn(total_entries,
                                 width - 1,
                                 dim,
                                 device=device,
                                 dtype=itype).transpose(1, 2)

Review comment: ditto

    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_update_triton(x,
                                      conv_state,
                                      weight,
                                      bias,
                                      activation=activation,
                                      conv_state_indices=padded_state_indices,
                                      pad_slot_id=PAD_SLOT_ID)
    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.equal(conv_state[unused_states_bool],
                       conv_state_for_padding_test[unused_states_bool])
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.bfloat16]) | ||
@pytest.mark.parametrize("silu_activation", [True]) | ||
@pytest.mark.parametrize("has_bias", [True]) | ||
@pytest.mark.parametrize("width", [4]) | ||
@pytest.mark.parametrize('seqlen', [8, 16, 784, 1024, 2048, 2049, 4096]) | ||
@pytest.mark.parametrize('dim', [64, 4096]) | ||
@pytest.mark.parametrize('with_padding', [True, False]) | ||
@pytest.mark.parametrize('batch', [4]) | ||
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width, | ||
has_bias, silu_activation, itype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you rename this to |
||
device = "cuda" | ||
torch.cuda.empty_cache() | ||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) | ||
if itype == torch.bfloat16: | ||
rtol, atol = 1e-2, 5e-2 | ||
# set seed | ||
current_platform.seed_everything(0) | ||
seqlens = [] | ||
batch_size = batch | ||
padding = 3 if with_padding else 0 | ||
padded_batch_size = batch_size + padding | ||
nsplits = padded_batch_size - 1 | ||
|
||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values | ||
|
||
seqlens.append( | ||
torch.diff( | ||
torch.cat( | ||
[torch.tensor([-1]), eos_pos, | ||
torch.tensor([seqlen - 1])])).tolist()) | ||
assert sum(seqlens[-1]) == seqlen | ||
assert all(s > 0 for s in seqlens[-1]) | ||
|
||
total_entries = batch_size * 10 | ||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) | ||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], | ||
dim=0) | ||
channel_last = True | ||
if not channel_last: | ||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, | ||
dtype=itype)[:, 4096:4096 + dim, :] | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto, remove |
||
x = rearrange( | ||
torch.randn(1, seqlen, 4096 + dim + 64, device=device, | ||
dtype=itype), "b s d -> b d s")[:, 4096:4096 + dim, :] | ||
|
    weight = torch.randn(dim, width, device=device, dtype=itype)

    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = None if not silu_activation else "silu"
    if not channel_last:
        final_states = torch.randn(total_entries,
                                   dim,
                                   width - 1,
                                   device=x.device,
                                   dtype=x.dtype)
    else:
        final_states = torch.randn(total_entries,
                                   width - 1,
                                   dim,
                                   device=x.device,
                                   dtype=x.dtype).transpose(1, 2)

Review comment: ditto
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)
    out = causal_conv1d_fn_triton(x.squeeze(0),
                                  weight,
                                  bias=bias,
                                  conv_states=final_states,
                                  query_start_loc=cumsum.cuda(),
                                  cache_indices=padded_state_indices,
                                  has_initial_states=has_initial_states,
                                  activation=activation,
                                  pad_slot_id=PAD_SLOT_ID)

    out_ref = []
    out_ref_b = []

    splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
    for i in range(len(seqlens[0])):
        x_s = [v[i].unsqueeze(0) for v in splits][0]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_ref(
                x_s,
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[
                    padded_state_indices[i]].unsqueeze(0),
                initial_states=final_states_ref[padded_state_indices[i]].
                unsqueeze(0) if has_initial_states[i] else None))
    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
    out_ref_tensor = torch.cat(out_ref, dim=0)

    try:
        assert torch.allclose(final_states[state_indices],
                              final_states_ref[state_indices],
                              rtol=rtol,
                              atol=atol)
        print("Passed conv_state")
    except Exception as e:
        print("FAILED conv_state")
        raise e
    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    try:
        assert torch.allclose(unpadded_out,
                              out_ref_tensor,
                              rtol=rtol,
                              atol=atol)
    except Exception as e:
        input(
            "Passed conv_state, but failed output: Press Enter to continue...")

        nz = out_ref_tensor.squeeze(0) - unpadded_out
        non_zero_indices = torch.nonzero(nz)
        print('nonzero indices :', non_zero_indices)
        raise e
Changes to the Mamba2 metadata preparation code:
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
@@ -22,6 +24,31 @@ class Mamba2Metadata:
    chunk_indices: torch.Tensor
    chunk_offsets: torch.Tensor

    num_cache_lines: Optional[int] = None
    stride_istate_seq: Optional[int] = None
    stride_istate_dim: Optional[int] = None
    stride_istate_token: Optional[int] = None
    seqlens: Optional[np.ndarray] = None
    padded_batch: Optional[int] = None
    nums_dict: Optional[dict] = None

Review comment: What is nums_dict?
Review comment: In a batch of requests, a prefill request can be processed in parallel, where each Triton program handles BLOCK_M tokens. Depending on the choice of BLOCK_M, the values in … I added the documentation accordingly.
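To make the reply above concrete, here is a rough sketch of the kind of per-BLOCK_M bookkeeping such a dict could hold (hypothetical structure and names, not the PR's actual layout):

import math


def build_nums_dict(seqlens, block_m_candidates=(8, 16, 32, 64)):
    """For each candidate BLOCK_M, record how many Triton programs each
    prefill sequence needs (ceil(seqlen / BLOCK_M)) and the total grid
    size across the batch."""
    nums_dict = {}
    for block_m in block_m_candidates:
        nums = [math.ceil(s / block_m) for s in seqlens]
        nums_dict[block_m] = {"nums": nums, "tot": sum(nums)}
    return nums_dict


# Example: three prefill sequences of 5, 130, and 1024 tokens
print(build_nums_dict([5, 130, 1024])[64])  # {'nums': [1, 3, 16], 'tot': 20}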
    is_channel_last: bool = True
    stride_w_dim: Optional[int] = None
    stride_w_width: Optional[int] = None
    width: Optional[int] = None
    np2_statelen: Optional[int] = None
    stride_x_seq: Optional[int] = 0
    stride_x_dim: Optional[int] = None
    stride_x_token: Optional[int] = None
    dim: Optional[int] = None
    cu_seqlen: Optional[int] = None
    out: Optional[torch.Tensor] = None

Review comment: Could you explain what this is?
Review comment: The Triton causal-conv1d (prefill or mixed prefill/decode) kernel processes the input …
Review comment: Since …
    stride_o_seq: Optional[int] = 0
    stride_o_dim: Optional[int] = None
    stride_o_token: Optional[int] = None
    MAX_NUM_PROGRAMS: int = 1024
    batch_ptr: Optional[torch.tensor] = None
    token_chunk_offset_ptr: Optional[torch.tensor] = None

Review comment: There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?
Review comment: The information here is reused across Mamba layers; even a stride call triggers Torch calls, which adds unnecessary overhead. I can add a description as needed. Please let me know what you think @tlrmchlsmth
Review comment: Done, updated the code.
Review comment: I think most of this should be removed. For CPU overheads in decode we can rely on CUDA graphs, and for prefill they are amortized.
Review comment: I'm running some tests to revert this. DONE
Review comment: @tlrmchlsmth: done the update based on your feedback.
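As a minimal sketch of the caching argument in the reply above (hypothetical names; the PR stores the cached values on Mamba2Metadata itself):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _StrideCacheSketch:
    # Filled in once per forward pass; every Mamba layer then reads plain
    # Python ints instead of calling Tensor.stride()/shape at each launch.
    stride_x_dim: Optional[int] = None
    stride_x_token: Optional[int] = None
    dim: Optional[int] = None


def cache_once(x: torch.Tensor, meta: _StrideCacheSketch) -> None:
    if meta.stride_x_dim is None:
        meta.stride_x_dim, meta.stride_x_token = x.stride(0), x.stride(1)
        meta.dim = x.shape[0]


x = torch.randn(4096, 256)  # (dim, tokens) layout
meta = _StrideCacheSketch()
cache_once(x, meta)
print(meta.stride_x_dim, meta.stride_x_token, meta.dim)  # 256 1 4096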

def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                              chunk_size: int,
@@ -62,7 +89,9 @@ def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
def prepare_mamba2_metadata(
    chunk_size: int,
    attn_metadata: AttentionMetadata,
    mamba2_metadata=None,
) -> Mamba2Metadata:
    # ruff: noqa: E501

    # compute number of prefill and decode requests
    # NOTE: in V0 we assume prefills are before decodes
@@ -78,6 +107,12 @@ def prepare_mamba2_metadata(

    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
    if num_prefills > 0:
        # NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
        # TODO: maybe revert to the original code (below) if the above no longer holds
        # has_initial_states = attn_metadata.context_lens_tensor > 0
        # zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
        # mamba_cache_params.ssm_state[zero_init_indices] = 0
        # initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]

Review comment: We can rely on batch reordering and require that it be used for this implementation.
Review comment: The change from a previous PR (cuda-split) implies this assumption; this PR doesn't make this assumption, which makes it more suitable for the vLLM v1 design. The comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth
Review comment: removed comment
Review comment: Yes, please remove the comment. We can rely on batch reordering even in vLLM V1, so this is a non-issue.
Review comment: done

        if (isinstance(attn_metadata,
                       (FlashAttentionMetadata, XFormersMetadata,
                        PlaceholderAttentionMetadata))
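For illustration, the ':num_prefills' slicing mentioned in the NOTE above could look roughly like the following (a sketch under the batch-reordering assumption, not the PR's exact code):

import torch


def compute_has_initial_states(context_lens_tensor: torch.Tensor,
                               num_prefills: int) -> torch.Tensor:
    # With prefill requests ordered before decode requests, only the first
    # `num_prefills` context lengths matter; a nonzero length means the
    # request already has a cached conv/SSM state to continue from.
    return context_lens_tensor[:num_prefills] > 0


print(compute_has_initial_states(torch.tensor([7, 0, 3, 1, 1]), 3))
# tensor([ True, False,  True])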
@@ -103,6 +138,21 @@ def prepare_mamba2_metadata(
                _query_start_loc_to_chunk_indices_offsets(
                    query_start_loc, chunk_size, num_prefill_tokens)

    if mamba2_metadata is not None:
        mamba2_metadata.has_initial_states = has_initial_states
        mamba2_metadata.prep_initial_states = prep_initial_states
        mamba2_metadata.chunk_size = chunk_size
        mamba2_metadata.seq_idx = seq_idx
        mamba2_metadata.chunk_indices = chunk_indices
        mamba2_metadata.chunk_offsets = chunk_offsets
        # We use 2 reset flags:
        # * mamba2_metadata.width is None  # update config at the first run
        #   (never changes for the whole session for a given model;
        #   becomes available at the first layer, e.g. conv_weights)
        # * mamba2_metadata.cu_seqlen is None  # update config specific to each input
        #   (becomes available at the first layer, e.g. conv_weights)
        mamba2_metadata.cu_seqlen = None  # supposed to be updated at each input

        return mamba2_metadata
    return Mamba2Metadata(has_initial_states=has_initial_states,
                          prep_initial_states=prep_initial_states,
                          chunk_size=chunk_size,
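A minimal sketch of how the two reset flags noted in the comments above could be consumed, assuming conv_weights has shape (dim, width) and x has shape (dim, num_tokens); the helper and shapes are hypothetical:

import torch


class _MetaSketch:
    width = None      # per-model config: set once, never changes afterwards
    cu_seqlen = None  # per-input config: reset to None for every new batch


def maybe_update(meta: _MetaSketch, conv_weights: torch.Tensor,
                 x: torch.Tensor) -> None:
    if meta.width is None:
        # First run for this model: record static kernel configuration.
        meta.width = conv_weights.shape[-1]
    if meta.cu_seqlen is None:
        # First Mamba layer of this forward pass: record input-specific
        # configuration; later layers in the same pass reuse it as-is.
        meta.cu_seqlen = x.shape[-1]


meta = _MetaSketch()
maybe_update(meta, torch.randn(64, 4), torch.randn(64, 128))
print(meta.width, meta.cu_seqlen)  # 4 128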
Review comment: Should we test more than just seqlen 1?
Review comment: Ah yes, I can add more to the test code.
Review comment: done
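One way to widen the coverage, with illustrative values only (the PR may have chosen a different set), is to extend the single-value seqlen parametrization of the decode-update test shown earlier:

@pytest.mark.parametrize("seqlen", [1, 4, 5])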