Skip to content

Commit d90f51c

Browse files
tjtanaavllmellm
authored andcommitted
[ROCm] [Bugfix] [Critical]: Fix mamba compilation bug (vllm-project#20883)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent f04a690 commit d90f51c

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
#include <c10/util/BFloat16.h>
99
#include <c10/util/Half.h>
10-
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10+
#ifdef USE_ROCM
11+
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
12+
#else
13+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
14+
#endif
1115

1216
#ifndef USE_ROCM
1317
#include <cub/block/block_load.cuh>
@@ -320,8 +324,13 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
320324
dim3 grid(params.batch, params.dim / kNRows);
321325
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
322326
if (kSmemSize >= 48 * 1024) {
327+
#ifdef USE_ROCM
328+
C10_HIP_CHECK(hipFuncSetAttribute(
329+
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
330+
#else
323331
C10_CUDA_CHECK(cudaFuncSetAttribute(
324332
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
333+
#endif
325334
}
326335
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
327336
C10_CUDA_KERNEL_LAUNCH_CHECK();

0 commit comments

Comments
 (0)