
Commit a547aeb

feat(rocm-support): support mamba2 on rocm (vllm-project#18565)

Signed-off-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
Co-authored-by: Islam Almersawi <islam.almersawi@openinnovation.ai>

1 parent fc6d0c2 · commit a547aeb

File tree: 5 files changed, +60 -49 lines


CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -232,6 +232,8 @@ endif()
 #

 set(VLLM_EXT_SRC
+  "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
+  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"

@@ -287,8 +289,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)

   list(APPEND VLLM_EXT_SRC
-    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
-    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 4 additions & 6 deletions
@@ -13,6 +13,10 @@
 #include <cub/block/block_load.cuh>
 #include <cub/block/block_store.cuh>

+#ifdef USE_ROCM
+  namespace cub = hipcub;
+#endif
+
 #include "static_switch.h"

@@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
     auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

     if (kSmemSize >= 48 * 1024) {
-      #ifndef USE_ROCM
-      C10_CUDA_CHECK(cudaFuncSetAttribute(
-          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-      #else
-      // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
       C10_CUDA_CHECK(cudaFuncSetAttribute(
           (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
       std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-      #endif
     }
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
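The second hunk collapses the two C10_CUDA_CHECK branches into one: casting the kernel pointer to (void *) satisfies both toolchains, since CUDA accepts an untyped function pointer while the HIP call the line becomes after hipification (hipFuncSetAttribute) only takes const void *. The new namespace cub = hipcub alias keeps the cub:: spellings in the block load/store code valid once the includes are rewritten to hipcub. Below is a minimal standalone sketch of the same shared-memory opt-in pattern used by causal_conv1d_fwd_launch and selective_scan_fwd_launch; demo_kernel and the main() driver are illustrative only, not part of the patch.

// Sketch only: demo_kernel and main() are hypothetical, not from the patch.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float *out) {
  extern __shared__ float smem[];                   // dynamic shared memory
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
  const int kSmemSize = 64 * 1024;                  // above the 48 KiB default, so opt-in is needed
  auto kernel = &demo_kernel;
  if (kSmemSize >= 48 * 1024) {
    // The (void *) cast compiles against both signatures: CUDA's const void* overload
    // and hipFuncSetAttribute, which has no templated kernel-pointer overload.
    cudaFuncSetAttribute((void *) kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
  }
  float *out = nullptr;
  cudaMalloc(&out, 256 * sizeof(float));
  kernel<<<1, 256, kSmemSize>>>(out);
  printf("launch: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
  cudaFree(out);
  return 0;
}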

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     auto kernel = &selective_scan_fwd_kernel<Ktraits>;
     if (kSmemSize >= 48 * 1024) {
       C10_CUDA_CHECK(cudaFuncSetAttribute(
-          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+          (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
     }
     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();

csrc/torch_bindings.cpp

Lines changed: 35 additions & 35 deletions
@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      "bool delta_softplus,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "Tensor! ssm_states,"
-      "int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"

@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);

+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation,"
+      "Tensor? cache_seqlens_,"
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
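The op schemas are unchanged; the three Mamba registrations simply move out of the block that sits among CUDA-only kernels into the section shared by both backends, presumably so that they are also registered on ROCm builds, mirroring the CMakeLists change. A single torch::kCUDA impl can serve both backends on the assumption that ROCm builds of PyTorch dispatch HIP tensors through the CUDA key. A minimal, self-contained registration sketch under that assumption follows; the demo_ops namespace and demo_scan placeholder are hypothetical, not vLLM's API.

// Sketch only: demo_ops / demo_scan are illustrative names, not part of the patch.
#include <torch/torch.h>
#include <torch/library.h>

// Placeholder body; a real op would launch the CUDA/ROCm kernel here.
void demo_scan(torch::Tensor& x) { x.mul_(2); }

TORCH_LIBRARY(demo_ops, ops) {
  // Schema with an in-place ("Tensor!") argument, in the same style as selective_scan_fwd.
  ops.def("demo_scan(Tensor! x) -> ()");
  // Registered under torch::kCUDA, which (per the assumption above) also covers HIP tensors.
  ops.impl("demo_scan", torch::kCUDA, &demo_scan);
}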

vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 18 additions & 5 deletions
@@ -5,10 +5,9 @@
 import torch

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
+from vllm.platforms import current_platform


 @dataclass

@@ -23,6 +22,21 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor


+def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
+    """Returns the appropriate metadata classes for the current platform."""
+    if current_platform.is_rocm():
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata)
+        return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
+    elif current_platform.is_cuda():
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+        from vllm.attention.backends.xformers import XFormersMetadata
+        return (FlashAttentionMetadata, XFormersMetadata,
+                PlaceholderAttentionMetadata)
+    raise ValueError(
+        f"Unsupported platform for Mamba2: {current_platform.device_type}")
+
+
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):

@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(

     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     if num_prefills > 0:
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
+        attn_metadata_instances = get_platform_metadata_classes()
+        if (isinstance(attn_metadata, attn_metadata_instances)
                 and attn_metadata.context_lens_tensor is not None):
             has_initial_states = \
                 attn_metadata.context_lens_tensor[:num_prefills] > 0  # [batch,]
