Skip to content

Commit dc53980

Browse files
committed
Optimize ROCm global to LDS transfer in sparse Marlin MMA
Replace __builtin_amdgcn_global_load_lds with inline assembly using ds_load_b instruction for more precise and direct global to local data store (LDS) transfer on MI300X AMD GPUs.
1 parent 04014e7 commit dc53980

File tree

1 file changed

+15
-3
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+15
-3
lines changed

torchao/csrc/cuda/sparse_marlin/mem.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
5151
int src_in_bytes = (zfill ? 0 : BYTES);
5252
uint32_t smem = cvta_to_shared(smem_ptr);
5353
#ifdef USE_ROCM
54-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
54+
// Use LDS.G instruction for global to LDS transfer on MI300X
55+
asm volatile(
56+
"{\n"
57+
" ds_load_b%c2 %0, %1\n"
58+
"}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES));
5559
#else
5660
asm volatile(
5761
"{\n"
@@ -68,7 +72,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
6872
const int BYTES = 16;
6973
uint32_t smem = cvta_to_shared(smem_ptr);
7074
#ifdef USE_ROCM
71-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
75+
// Use LDS.G instruction for global to LDS transfer on MI300X
76+
asm volatile(
77+
"{\n"
78+
" ds_load_b%c2 %0, %1\n"
79+
"}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES));
7280
#else
7381
asm volatile(
7482
"{\n"
@@ -85,7 +93,11 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
8593
const int BYTES = 16;
8694
uint32_t smem = cvta_to_shared(smem_ptr);
8795
#ifdef USE_ROCM
88-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
96+
// Use LDS.G instruction for global to LDS transfer on MI300X
97+
asm volatile(
98+
"{\n"
99+
" ds_load_b%c2 %0, %1\n"
100+
"}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES));
89101
#else
90102
asm volatile(
91103
"{\n"

0 commit comments

Comments
 (0)