Skip to content

Commit bf81763

Browse files
committed
Add ROCm support for Marlin kernel function attribute setting
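Modify the Marlin kernel launch code to set the kernel's function attribute through the HIP/ROCm-specific API (`hipFuncSetAttribute`, with the kernel cast to `const void*`) when compiling for AMD platforms (`__HIP_PLATFORM_AMD__` defined), while keeping the existing `cudaFuncSetAttribute` call for CUDA builds. This lets the maximum dynamic shared memory size (`max_shared_mem`) be configured for the kernel on both platforms.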
1 parent 6f43e01 commit bf81763

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,10 +861,19 @@ __global__ void Marlin_24(
861861
thread_n_blocks == THREAD_N_BLOCKS && \
862862
thread_k_blocks == THREAD_K_BLOCKS && \
863863
group_blocks == GROUP_BLOCKS) { \
864-
cudaFuncSetAttribute( \
865-
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
866-
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
867-
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
864+
#ifdef __HIP_PLATFORM_AMD__
865+
// For ROCm/HIP, cast the kernel function to const void* for hipFuncSetAttribute
866+
hipFuncSetAttribute(
867+
reinterpret_cast<const void*>(&Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,
868+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>),
869+
hipFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
870+
#else
871+
// For CUDA, use the template function directly
872+
cudaFuncSetAttribute(
873+
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,
874+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,
875+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
876+
#endif
868877
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
869878
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
870879
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \

0 commit comments

Comments
 (0)