Skip to content

Commit 47043eb

Browse files
thoangtrvntmhoangttlrmchlsmth
authored
[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent 31b96d1 commit 47043eb

File tree

15 files changed

+1120
-1145
lines changed

15 files changed

+1120
-1145
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ endif()
232232

233233
set(VLLM_EXT_SRC
234234
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
235-
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
236235
"csrc/cache_kernels.cu"
237236
"csrc/attention/paged_attention_v1.cu"
238237
"csrc/attention/paged_attention_v2.cu"

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 0 additions & 656 deletions
This file was deleted.

csrc/mamba/causal_conv1d/causal_conv1d.h

Lines changed: 0 additions & 159 deletions
This file was deleted.

csrc/mamba/causal_conv1d/static_switch.h

Lines changed: 0 additions & 28 deletions
This file was deleted.

csrc/ops.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
326326
const std::optional<torch::Tensor>& has_initial_state,
327327
const torch::Tensor& ssm_states, int64_t pad_slot_id);
328328

329-
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
330-
const at::Tensor& weight,
331-
const std::optional<at::Tensor>& bias_,
332-
bool silu_activation,
333-
const std::optional<at::Tensor>& cache_seqlens_,
334-
const std::optional<at::Tensor>& conv_state_indices_,
335-
int64_t pad_slot_id);
336-
337-
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
338-
const std::optional<at::Tensor>& bias_,
339-
const std::optional<at::Tensor>& conv_states,
340-
const std::optional<at::Tensor>& query_start_loc,
341-
const std::optional<at::Tensor>& cache_indices,
342-
const std::optional<at::Tensor>& has_initial_state,
343-
bool silu_activation, int64_t pad_slot_id);
344-
345329
using fptr_t = int64_t;
346330
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
347331
torch::Tensor& rank_data, int64_t rank,

csrc/torch_bindings.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
594594
"int pad_slot_id) -> ()");
595595
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
596596

597-
ops.def(
598-
"causal_conv1d_update(Tensor! x,"
599-
"Tensor! conv_state,"
600-
"Tensor! weight,"
601-
"Tensor? bias_,"
602-
"bool silu_activation,"
603-
"Tensor? cache_seqlens_,"
604-
"Tensor? conv_state_indices,"
605-
"int pad_slot_id) -> ()");
606-
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
607-
608-
ops.def(
609-
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
610-
"Tensor? bias_,"
611-
"Tensor!? conv_states,"
612-
"Tensor? query_start_loc,"
613-
"Tensor? cache_indices,"
614-
"Tensor? has_initial_state,"
615-
"bool silu_activation,"
616-
"int pad_slot_id) -> ()");
617-
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
618-
619597
#ifndef USE_ROCM
620598
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
621599
ops.def(

0 commit comments

Comments
 (0)