
[Kernel] Triton implementation of causal-conv1d for Mamba-based models #18218

Merged · 48 commits · Jul 9, 2025
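This PR replaces the CUDA causal-conv1d kernels with a Triton implementation. For orientation, here is a minimal PyTorch sketch of the operation being reimplemented. It assumes the usual Mamba convention (one short depthwise filter per channel, left padding, optional SiLU) and is illustrative only, not the Triton kernel added by this PR.

```python
# Illustrative reference only -- not the Triton kernel from this PR.
# Assumed shapes follow the common Mamba convention:
#   x:      (batch, dim, seqlen)  activations
#   weight: (dim, width)          one short filter per channel (depthwise)
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x, weight, bias=None, activation=None):
    dim, width = weight.shape
    # Depthwise conv, left-padded by width-1 so position t sees only inputs <= t.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., : x.shape[-1]]  # drop trailing positions created by the padding
    return F.silu(out) if activation == "silu" else out
```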
Commits (48)
ad83738
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
f4c56bf
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
dfa7159
resolve merge conflict
tmhoangt May 15, 2025
61d7ed9
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
8882cef
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
775e561
refactor for code style
tmhoangt May 16, 2025
939a823
refactor for code style
tmhoangt May 16, 2025
29b7941
refactor for code style
tmhoangt May 16, 2025
7bfe0e8
refactor for code style
tmhoangt May 16, 2025
52d601c
refactor for code style
tmhoangt May 16, 2025
081a8be
Update tests/kernels/mamba/test_causal_conv1d.py
thoangtrvn Jun 2, 2025
9eb1cc3
update test code to cover more use-cases
tmhoangt Jun 2, 2025
091b31e
refactor code based on feedback
tmhoangt Jun 4, 2025
bfabaae
refactor code based on feedback
tmhoangt Jun 4, 2025
da660f0
refactor code based on feedback
tmhoangt Jun 4, 2025
7af7f58
refactor code based on feedback
tmhoangt Jun 4, 2025
10e332c
Merge branch 'main' into pr_conv1d_triton
thoangtrvn Jun 4, 2025
ecb3a2c
refactor code based on feedback
tmhoangt Jun 4, 2025
107911a
refactor code based on feedback
tmhoangt Jun 4, 2025
bfc2f28
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
ef21b3d
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
400e669
Merge branch 'pr_conv1d_triton' of github.com:thoangtrvn/vllm into pr…
tmhoangt Jun 4, 2025
f0be762
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
4cfb12d
revert code change based on feedback
tmhoangt Jun 5, 2025
64ee33d
revert code change based on feedback
tmhoangt Jun 5, 2025
19586c5
revert code change based on feedback
tmhoangt Jun 5, 2025
e3192e8
migrate code change based on feedback
tmhoangt Jun 5, 2025
8aad208
migrate code change based on feedback
tmhoangt Jun 5, 2025
a0d2170
revert code change based on feedback
tmhoangt Jun 5, 2025
4d1bb63
revert code change based on feedback
tmhoangt Jun 5, 2025
679eb1c
migrate code change based on feedback
tmhoangt Jun 5, 2025
c782f25
fix merge conflict from upstream/main
tmhoangt Jun 5, 2025
6d0e77a
reduce kernel test time
tmhoangt Jun 10, 2025
20a34c5
remove CUDA causal-conv1d kernel
tmhoangt Jun 10, 2025
82091a7
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 10, 2025
6784173
remove unused code based on feedback
tmhoangt Jun 10, 2025
6e8d966
update argument name
tmhoangt Jun 11, 2025
089b10b
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 26, 2025
761bdea
Use typing.Union to work with Python 3.9
tmhoangt Jun 26, 2025
7448f0d
move _query_start_loc_to_chunk_indices_offsets to mamba_attn.py to avoid
tmhoangt Jun 26, 2025
5e41d6b
Merge remote-tracking branch 'upstream/main' into pr_conv1d_triton
tmhoangt Jun 28, 2025
bbef3ac
Update vllm/v1/attention/backends/mamba_attn.py
thoangtrvn Jun 30, 2025
6527b9d
revert space change in zamba2.py and address comments
tmhoangt Jun 30, 2025
129b32d
revert to using `has_initial_state` argument for causal_conv1d_fn, fix
tmhoangt Jul 8, 2025
37f801a
revert to using `has_initial_state` argument for causal_conv1d_fn
tmhoangt Jul 8, 2025
a208d04
update code to work in v1
tmhoangt Jul 8, 2025
a798b14
make typing compatible Python 3.9
tmhoangt Jul 8, 2025
736eeba
Merge branch 'main' into pr_conv1d_triton
tlrmchlsmth Jul 9, 2025
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -232,7 +232,6 @@ endif()

 set(VLLM_EXT_SRC
   "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
-  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
656 changes: 0 additions & 656 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu

This file was deleted.

159 changes: 0 additions & 159 deletions csrc/mamba/causal_conv1d/causal_conv1d.h

This file was deleted.

28 changes: 0 additions & 28 deletions csrc/mamba/causal_conv1d/static_switch.h

This file was deleted.

16 changes: 0 additions & 16 deletions csrc/ops.h
@@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const std::optional<torch::Tensor>& has_initial_state,
                         const torch::Tensor& ssm_states, int64_t pad_slot_id);

-void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
-                          const at::Tensor& weight,
-                          const std::optional<at::Tensor>& bias_,
-                          bool silu_activation,
-                          const std::optional<at::Tensor>& cache_seqlens_,
-                          const std::optional<at::Tensor>& conv_state_indices_,
-                          int64_t pad_slot_id);
-
-void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
-                       const std::optional<at::Tensor>& bias_,
-                       const std::optional<at::Tensor>& conv_states,
-                       const std::optional<at::Tensor>& query_start_loc,
-                       const std::optional<at::Tensor>& cache_indices,
-                       const std::optional<at::Tensor>& has_initial_state,
-                       bool silu_activation, int64_t pad_slot_id);
-
 using fptr_t = int64_t;
 fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,

Review comment from the Contributor Author on the removed declarations: "remove pybind11 C code and CUDA kernel"
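Of the two removed declarations, `causal_conv1d_update` is the single-token decode step. A hedged sketch of its semantics follows; names are illustrative, and the `cache_seqlens_` / `conv_state_indices_` batching plumbing is omitted. The update amounts to rolling a per-channel window and taking a dot product with the filter.

```python
# Sketch of the per-token decode update; illustrative, not vLLM's implementation.
import torch
import torch.nn.functional as F

def causal_conv1d_update_ref(x, conv_state, weight, bias=None, silu_activation=False):
    # x: (batch, dim) new token; conv_state: (batch, dim, width) rolling input window.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # drop the oldest input
    conv_state[:, :, -1] = x                                      # append the newest input
    out = torch.einsum("bdw,dw->bd", conv_state, weight)          # depthwise dot product
    if bias is not None:
        out = out + bias
    return F.silu(out) if silu_activation else out
```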
22 changes: 0 additions & 22 deletions csrc/torch_bindings.cpp
@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int pad_slot_id) -> ()");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
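The removed `causal_conv1d_fwd` schema also threads `conv_states` and `has_initial_state` through prefill. A single-sequence sketch of those semantics follows; names are illustrative, the varlen `query_start_loc` / `cache_indices` handling is omitted, and `width > 1` with `seqlen >= width - 1` is assumed.

```python
# Single-sequence sketch of prefill with an optional cached window; illustrative only.
import torch
import torch.nn.functional as F

def causal_conv1d_fwd_ref(x, weight, conv_state, bias=None,
                          has_initial_state=False, silu_activation=False):
    # x: (dim, seqlen); weight: (dim, width); conv_state: (dim, width - 1).
    dim, width = weight.shape
    seqlen = x.shape[-1]
    if has_initial_state:
        x = torch.cat([conv_state, x], dim=-1)  # prepend the cached inputs
    out = F.conv1d(x.unsqueeze(0), weight.unsqueeze(1), bias,
                   padding=0 if has_initial_state else width - 1,
                   groups=dim).squeeze(0)[..., :seqlen]
    conv_state.copy_(x[..., -(width - 1):])  # cache the last width-1 inputs for decode
    return F.silu(out) if silu_activation else out
```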