Skip to content

musa: upgrade musa sdk to rc4.2.0 #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ if(MUSAToolkit_FOUND)
install(TARGETS ggml-musa
RUNTIME_DEPENDENCIES
DIRECTORIES ${MUSAToolkit_BIN_DIR} ${MUSAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES mudnn mublas musart musa
PRE_INCLUDE_REGEXES mublas musart musa
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_MUSA_INSTALL_DIR} COMPONENT MUSA
LIBRARY DESTINATION ${OLLAMA_MUSA_INSTALL_DIR} COMPONENT MUSA
Expand Down
15 changes: 12 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG MUSAVERSION=rc4.0.1
ARG MUSAVERSION=rc4.2.0
ARG UBUNTUVERSION=22.04

# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
Expand Down Expand Up @@ -103,7 +104,7 @@ FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama

# Moore Threads (MUSA) build stages
FROM mthreads/musa:${MUSAVERSION}-mudnn-devel-ubuntu22.04 AS musa-4
FROM mthreads/musa:${MUSAVERSION}-devel-ubuntu${UBUNTUVERSION}-amd64 AS musa-4
RUN apt-get update \
&& apt-get install -y curl \
&& apt-get clean \
Expand All @@ -117,6 +118,14 @@ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MUSA 4' \
&& cmake --build --parallel --preset 'MUSA 4' \
&& cmake --install build --component MUSA --strip --parallel 8
# TODO: Remove the following lines in next release
RUN cd dist/lib/ollama/musa_v4 && \
for f in libmusa.so.*; do \
if [ -f "$f" ] && [[ "$f" =~ ^libmusa\.so\.[0-9]+$ ]]; then \
ln -sf "$f" libmusa.so; \
break; \
fi; \
done

FROM scratch AS musa
COPY --from=musa-4 dist/lib/ollama/musa_v4 /lib/ollama/musa_v4
Expand All @@ -125,7 +134,7 @@ FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama

FROM ubuntu:22.04
FROM ubuntu:${UBUNTUVERSION}
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
Expand Down
8 changes: 4 additions & 4 deletions ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#ifdef GGML_USE_MUSA
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

Expand Down Expand Up @@ -645,11 +645,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#ifdef GGML_USE_MUSA
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
} else
#endif // GGML_USE_MUSA
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
Expand Down
4 changes: 2 additions & 2 deletions ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_DEFAULT_MATH
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_16BF MUSA_R_16BF
#define CUDA_R_32F MUSA_R_32F
Expand All @@ -29,7 +29,7 @@
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cublasGetStatusString mublasGetStatusString
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
Expand Down
18 changes: 14 additions & 4 deletions ml/backend/ggml/ggml/src/ggml-musa/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-musa/*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})

if (GGML_MUSA_MUDNN_COPY)
file(GLOB SRCS "../ggml-musa/*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
add_compile_definitions(GGML_MUSA_MUDNN_COPY)
endif()

if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
Expand Down Expand Up @@ -97,10 +101,16 @@ if (MUSAToolkit_FOUND)
endif()

if (GGML_STATIC)
# TODO: mudnn has not provided static libraries yet
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
# TODO: mudnn has not provided static libraries yet
# if (GGML_MUSA_MUDNN_COPY)
# target_link_libraries(ggml-musa PRIVATE mudnn_static)
# endif()
else()
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
if (GGML_MUSA_MUDNN_COPY)
target_link_libraries(ggml-musa PRIVATE mudnn)
endif()
endif()

if (GGML_CUDA_NO_VMM)
Expand Down