Skip to content

Commit 8105c67

Browse files
committed
musa: disable mudnnMemcpyAsync by default
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent 52cc72c commit 8105c67

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

ggml/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
174174
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
175175
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
176176
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
177-
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental" OFF)
177+
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
178+
option(GGML_MUSA_MUDNN_COPY "ggml: enable MUDNN for accelerated copy" OFF)
178179
option(GGML_VULKAN "ggml: use Vulkan" OFF)
179180
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
180181
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)

ggml/src/ggml-cuda/cpy.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "cpy.cuh"
22
#include "dequantize.cuh"
3-
#ifdef GGML_USE_MUSA
3+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
44
#include "ggml-musa/mudnn.cuh"
5-
#endif // GGML_USE_MUSA
5+
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
66

77
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
88

@@ -600,7 +600,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
600600
#endif
601601
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
602602
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
603-
#ifdef GGML_USE_MUSA
603+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
604604
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
605605
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
606606
} else

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
3434
list(APPEND GGML_SOURCES_MUSA ${SRCS})
3535
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
3636
list(APPEND GGML_SOURCES_MUSA ${SRCS})
37-
file(GLOB SRCS "../ggml-musa/*.cu")
38-
list(APPEND GGML_SOURCES_MUSA ${SRCS})
37+
38+
if (GGML_MUSA_MUDNN_COPY)
39+
file(GLOB SRCS "../ggml-musa/*.cu")
40+
list(APPEND GGML_SOURCES_MUSA ${SRCS})
41+
add_compile_definitions(GGML_MUSA_MUDNN_COPY)
42+
endif()
3943

4044
if (GGML_CUDA_FA_ALL_QUANTS)
4145
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -101,10 +105,16 @@ if (MUSAToolkit_FOUND)
101105
endif()
102106

103107
if (GGML_STATIC)
104-
# TODO: mudnn has not provided static libraries yet
105108
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
109+
# TODO: mudnn has not provided static libraries yet
110+
# if (GGML_MUSA_MUDNN_COPY)
111+
# target_link_libraries(ggml-musa PRIVATE mudnn_static)
112+
# endif()
106113
else()
107-
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
114+
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
115+
if (GGML_MUSA_MUDNN_COPY)
116+
target_link_libraries(ggml-musa PRIVATE mudnn)
117+
endif()
108118
endif()
109119

110120
if (GGML_CUDA_NO_VMM)

0 commit comments

Comments
 (0)