diff --git a/README.md b/README.md index 76ad6ebd..63728d21 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). diff --git a/doc/README.md b/doc/README.md index 8bd94a00..f0683a30 100644 --- a/doc/README.md +++ b/doc/README.md @@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

🔥 Updates

+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md). diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md index d9fa9b85..854549c5 100644 --- a/doc/SUMMARY.md +++ b/doc/SUMMARY.md @@ -10,6 +10,7 @@ - [Injection Tutorial](en/injection_tutorial.md) - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) - [Use FP8 GPU Kernel](en/fp8_kernel.md) +- [Use AMD GPU](en/ROCm.md) # Server - [Server](en/api/server/server.md) - [Website](en/api/server/website.md) diff --git a/doc/en/ROCm.md b/doc/en/ROCm.md new file mode 100644 index 00000000..39f48902 --- /dev/null +++ b/doc/en/ROCm.md @@ -0,0 +1,96 @@ +# ROCm Support for ktransformers (Beta) + +## Introduction + +### Overview +In our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been tested and developed using EPYC 9274F processors and AMD Radeon 7900xtx GPUs. + +## Installation Guide + +### 1. Install ROCm Driver +Begin by installing the ROCm drivers for your AMD GPU: +- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html) + +### 2. Set Up Conda Environment +We recommend using Miniconda3/Anaconda3 for environment management: + +```bash +# Download Miniconda +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + +# Create environment +conda create --name ktransformers python=3.11 +conda activate ktransformers + +# Install required libraries +conda install -c conda-forge libstdcxx-ng + +# Verify GLIBCXX version (should include 3.4.32) +strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX +``` + +> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3` + +### 3. Install PyTorch for ROCm +Install PyTorch with ROCm 6.2.4 support: + +```bash +pip3 install torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/rocm6.2.4 +pip3 install packaging ninja cpufeature numpy +``` + +> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/) + +### 4. Build ktransformers + +```bash +# Clone repository +git clone https://github.com/kvcache-ai/ktransformers.git +cd ktransformers +git submodule update --init + +# Optional: Compile web interface +# See: api/server/website.md + +# Install dependencies +bash install.sh +``` + +## Running DeepSeek-R1 Models + +### Configuration for 24GB VRAM GPUs +Use our optimized configuration for constrained VRAM: + +```bash +python ktransformers/local_chat.py \ + --model_path deepseek-ai/DeepSeek-R1 \ + --gguf_path \ + --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \ + --cpu_infer +``` + +> **Beta Note:** Current Q8 linear implementation (Marlin alternative) shows suboptimal performance. Expect optimizations in future releases. 
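To make the beta note concrete: the `KLinearQ8` operator introduced later in this PR replaces Marlin with a simple symmetric int8 scheme — each weight column gets one fp32 scale (`max|w| / 127`), the weights are stored as int8, and the forward pass dequantizes them back to fp32 before a plain matmul. The sketch below (illustrative helper names and shapes, not code from the repository) shows the round trip:

```python
import torch

def quantize_q8(w: torch.Tensor):
    # Symmetric per-column int8 quantization: one fp32 scale per weight column.
    max_abs = w.abs().amax(dim=0)
    max_abs = torch.where(max_abs == 0, torch.ones_like(max_abs), max_abs)  # avoid div-by-zero on all-zero columns
    scales = max_abs / 127.0
    q = torch.round(w / scales).to(torch.int8)
    return q, scales

def dequantize_q8(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Recover an fp32 approximation; the layer then computes x @ w_hat.T in fp32.
    return q.to(torch.float32) * scales

w = torch.randn(1024, 512)             # (out_features, in_features), illustrative shape
q, scales = quantize_q8(w)
w_hat = dequantize_q8(q, scales)
print((w - w_hat).abs().max().item())  # rounding error is bounded by ~half a scale step per column
```

Because the matmul still runs in fp32 on dequantized weights, this path trades speed for portability, which is one reason it currently lags the Marlin kernel used on CUDA.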
+
### Configuration for 40GB+ VRAM GPUs For better performance on high-VRAM GPUs: 1. Modify `DeepSeek-V3-Chat.yaml`: ```yaml # Replace all instances of: KLinearMarlin → KLinearTorch ``` 2. Execute with: ```bash python ktransformers/local_chat.py \ --model_path deepseek-ai/DeepSeek-R1 \ --gguf_path \ --optimize_config_path \ --cpu_infer ``` > **Tip:** If you have two 24GB AMD GPUs, you can apply the same modification and run with `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` instead. ## Known Limitations - Marlin operations are not supported on the ROCm platform - Current Q8 linear implementation shows reduced performance (Beta limitation) diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py index b100dcb1..a8bce455 100644 --- a/ktransformers/__init__.py +++ b/ktransformers/__init__.py @@ -8,4 +8,4 @@ LastEditors : chenxl LastEditTime : 2025-02-15 03:53:02 ''' -__version__ = "0.2.3.post1" +__version__ = "0.2.3post2" diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt index 08e578db..eefcadf0 100644 --- a/ktransformers/ktransformers_ext/CMakeLists.txt +++ b/ktransformers/ktransformers_ext/CMakeLists.txt @@ -32,6 +32,7 @@ endif() option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF) option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF) +option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF) # Architecture specific # TODO: probably these flags need to be tweaked on some architectures @@ -201,6 +202,31 @@ endif() # message(STATUS "Can't found CUDA lib") # endif() +if (NOT EXISTS $ENV{ROCM_PATH}) + if (NOT EXISTS /opt/rocm) + set(ROCM_PATH /usr) + else() + set(ROCM_PATH /opt/rocm) + endif() +else() + set(ROCM_PATH $ENV{ROCM_PATH}) +endif() + +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) +list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") + +if (NOT EXISTS $ENV{MUSA_PATH}) + if (NOT EXISTS /opt/musa) + set(MUSA_PATH /usr/local/musa) + else() + set(MUSA_PATH /opt/musa) + endif() +else() + set(MUSA_PATH $ENV{MUSA_PATH}) +endif() + +list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") + add_compile_options("$<$:${ARCH_FLAGS}>") add_compile_options("$<$:${ARCH_FLAGS}>") @@ -218,6 +244,14 @@ elseif (UNIX) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) endif() + if (KTRANSFORMERS_USE_ROCM) + find_package(HIP REQUIRED) + if(HIP_FOUND) + include_directories("${HIP_INCLUDE_DIRS}") + add_compile_definitions(KTRANSFORMERS_USE_ROCM=1) + endif() + endif() + if (KTRANSFORMERS_USE_MUSA) if (NOT EXISTS $ENV{MUSA_PATH}) if (NOT EXISTS /opt/musa) @@ -258,6 +292,11 @@ elseif(UNIX) endif() target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so") endif() + if (KTRANSFORMERS_USE_ROCM) + add_compile_definitions(USE_HIP=1) + target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") + message(STATUS "Building for HIP") + endif() if(KTRANSFORMERS_USE_MUSA) target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) endif() diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h index d0f0c603..d0f7b11b 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h +++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h @@ -7,79 +7,83 @@ * @LastEditTime : 2024-08-07 09:47:43 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
**/ -#ifndef CPUINFER_CPUINFER_H -#define CPUINFER_CPUINFER_H - -#include -#include -#include -#include -#include -#include -#include -#ifdef KTRANSFORMERS_USE_CUDA -#include "vendors/cuda.h" -#elif KTRANSFORMERS_USE_MUSA -#include "vendors/musa.h" -#endif - -#include "backend.h" -#include "task_queue.h" - -#include "llama.cpp/ggml-impl.h" - -class CPUInfer { - public: - CPUInfer(int thread_num) { - backend_ = new Backend(thread_num - 1); - task_queue_ = new TaskQueue(); - for (int i = 0; i < (1 << 16); ++i) { - ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i); - } - } - - ~CPUInfer() { - delete backend_; - delete task_queue_; - } - - template - void enqueue(Func f, Obj* obj, Args... args) { - task_queue_->enqueue([=]() { - std::invoke(f, *obj, args..., backend_); - }); - } - - void submit(std::pair params) { - void (*func)(void*) = (void (*)(void*))params.first; - void* args = (void*)params.second; - *((CPUInfer**)args) = this; - func(args); - } - - void sync() { - task_queue_->sync(); - } - - void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) { - void (*func)(void*) = (void (*)(void*))params.first; - void* args = (void*)params.second; - *((CPUInfer**)args) = this; - cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args); - } - - static void sync_(void* cpu_infer_ptr) { - CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr; - cpuinfer->sync(); - } - - void sync_with_cuda_stream(intptr_t user_cuda_stream) { - cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this); - } - - public: - Backend* backend_; - TaskQueue* task_queue_; -}; - -#endif \ No newline at end of file + #ifndef CPUINFER_CPUINFER_H + #define CPUINFER_CPUINFER_H + + #include + #include + #include + #include + #include + #include + #include + #ifdef KTRANSFORMERS_USE_CUDA + #include "vendors/cuda.h" + #elif KTRANSFORMERS_USE_MUSA + #include "vendors/musa.h" + #elif KTRANSFORMERS_USE_ROCM + #define __HIP_PLATFORM_AMD__ + #include "vendors/hip.h" + #endif + + #include "backend.h" + #include "task_queue.h" + #include "../vendors/vendor.h" + + #include "llama.cpp/ggml-impl.h" + + class CPUInfer { + public: + CPUInfer(int thread_num) { + backend_ = new Backend(thread_num - 1); + task_queue_ = new TaskQueue(); + for (int i = 0; i < (1 << 16); ++i) { + ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i); + } + } + + ~CPUInfer() { + delete backend_; + delete task_queue_; + } + + template + void enqueue(Func f, Obj* obj, Args... 
args) { + task_queue_->enqueue([=]() { + std::invoke(f, *obj, args..., backend_); + }); + } + + void submit(std::pair params) { + void (*func)(void*) = (void (*)(void*))params.first; + void* args = (void*)params.second; + *((CPUInfer**)args) = this; + func(args); + } + + void sync() { + task_queue_->sync(); + } + + void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) { + void (*func)(void*) = (void (*)(void*))params.first; + void* args = (void*)params.second; + *((CPUInfer**)args) = this; + cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args); + } + + static void sync_(void* cpu_infer_ptr) { + CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr; + cpuinfer->sync(); + } + + void sync_with_cuda_stream(intptr_t user_cuda_stream) { + cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this); + } + + public: + Backend* backend_; + TaskQueue* task_queue_; + }; + + #endif \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h index 082ad2c3..1746b073 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h @@ -1,3 +1,15 @@ #pragma once -#include \ No newline at end of file +#include +#include +#include +#include +#include + +#if CUDART_VERSION < 11020 +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t +#endif // CUDART_VERSION < 11020 diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h new file mode 100644 index 00000000..abbc1e89 --- /dev/null +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h @@ -0,0 +1,172 @@ +#pragma once + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#include +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif // __HIP_PLATFORM_AMD__ + +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 +#define cublasCreate hipblasCreate 
+#define cublasDestroy hipblasDestroy +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t +#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaHostRegister hipHostRegister +#define cudaHostRegisterPortable hipHostRegisterPortable +#define cudaHostRegisterReadOnly hipHostRegisterReadOnly +#define cudaHostUnregister hipHostUnregister +#define cudaLaunchHostFunc hipLaunchHostFunc +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemGetInfo hipMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamPerThread hipStreamPerThread +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) 
+#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#define cudaHostFn_t hipHostFn_t +#define __trap() do { abort(); __builtin_unreachable(); } while(0) +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED + +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) +#define GCN +#endif + +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) +#define CDNA +#endif + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef hip_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h index 18922218..6cc1b69e 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h @@ -1,9 +1,137 @@ #pragma once #include +#include +#include #include - +#include +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F 
MUSA_R_32F +#define cublasComputeType_t cudaDataType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t +#define cublasGetStatusString mublasStatus_to_string +#define cudaDataType_t musaDataType_t +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister #define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent #define cudaStream_t musaStream_t -#define cudaHostFn_t musaHostFn_t -#define nv_bfloat16 mt_bfloat16 \ No newline at end of file +#define cudaSuccess musaSuccess + +// Additional mappings for MUSA virtual memory pool +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED 
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// Additional mappings for MUSA graphs +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +typedef mt_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h new file mode 100644 index 00000000..84704389 --- /dev/null +++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h @@ -0,0 +1,13 @@ +#ifndef CPUINFER_VENDOR_VENDOR_H +#define CPUINFER_VENDOR_VENDOR_H + +#ifdef USE_CUDA +#include "cuda.h" +#elif USE_HIP +#define __HIP_PLATFORM_AMD__ +#include "hip.h" +#elif USE_MUSA +#include "musa.h" +#endif + +#endif // CPUINFER_VENDOR_VENDOR_H \ No newline at end of file diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu index e80efc45..23630356 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -16,6 +16,10 @@ #include #include +#ifdef KTRANSFORMERS_USE_ROCM +typedef hip_bfloat16 nv_bfloat16; +#endif + __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ diff --git 
a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu index 54e538ae..87f4581b 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu @@ -36,7 +36,7 @@ inline std::string str(T x) { namespace gptq_marlin { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh index 66a59203..ccf9cfd8 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh @@ -39,7 +39,7 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__) // No support for async #else diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh index b8babfb0..80f6ea43 100644 --- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh @@ -8,6 +8,11 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; +#endif + namespace gptq_marlin { template diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp index 902d4271..0078a79e 100644 --- a/ktransformers/ktransformers_ext/ext_bindings.cpp +++ b/ktransformers/ktransformers_ext/ext_bindings.cpp @@ -9,7 +9,6 @@ **/ // Python bindings #include "cpu_backend/cpuinfer.h" -#include "device_launch_parameters.h" #include "llamafile/flags.h" #include "operators/kvcache/kvcache.h" #include "operators/llamafile/linear.h" diff --git a/ktransformers/ktransformers_ext/vendors/cuda.h b/ktransformers/ktransformers_ext/vendors/cuda.h new file mode 100644 index 00000000..1746b073 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/cuda.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#include +#include +#include + +#if CUDART_VERSION < 11020 +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t +#endif // CUDART_VERSION < 11020 diff --git a/ktransformers/ktransformers_ext/vendors/hip.h b/ktransformers/ktransformers_ext/vendors/hip.h new file mode 100644 index 00000000..abbc1e89 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/hip.h @@ -0,0 +1,172 @@ +#pragma once + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#include +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif // __HIP_PLATFORM_AMD__ + +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define 
CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 +#define cublasCreate hipblasCreate +#define cublasDestroy hipblasDestroy +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t +#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaHostRegister hipHostRegister +#define cudaHostRegisterPortable hipHostRegisterPortable +#define cudaHostRegisterReadOnly hipHostRegisterReadOnly +#define cudaHostUnregister hipHostUnregister +#define cudaLaunchHostFunc hipLaunchHostFunc +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemGetInfo hipMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize 
hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamPerThread hipStreamPerThread +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#define cudaHostFn_t hipHostFn_t +#define __trap() do { abort(); __builtin_unreachable(); } while(0) +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED + +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) +#define GCN +#endif + +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) +#define CDNA +#endif + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || 
defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef hip_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/vendors/musa.h b/ktransformers/ktransformers_ext/vendors/musa.h new file mode 100644 index 00000000..6cc1b69e --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/musa.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include +#include +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F MUSA_R_32F +#define cublasComputeType_t cudaDataType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t +#define cublasGetStatusString mublasStatus_to_string +#define cudaDataType_t musaDataType_t +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo 
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// Additional mappings for MUSA virtual memory pool +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// Additional mappings for MUSA graphs +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +typedef mt_bfloat16 nv_bfloat16; diff --git a/ktransformers/ktransformers_ext/vendors/vendor.h b/ktransformers/ktransformers_ext/vendors/vendor.h new file mode 100644 index 00000000..84704389 --- /dev/null +++ b/ktransformers/ktransformers_ext/vendors/vendor.h @@ -0,0 +1,13 @@ +#ifndef CPUINFER_VENDOR_VENDOR_H +#define CPUINFER_VENDOR_VENDOR_H + +#ifdef USE_CUDA +#include "cuda.h" +#elif USE_HIP 
+#define __HIP_PLATFORM_AMD__ +#include "hip.h" +#elif USE_MUSA +#include "musa.h" +#endif + +#endif // CPUINFER_VENDOR_VENDOR_H \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index dc5747fa..1a5f55f9 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -31,6 +31,7 @@ from ktransformers.util.utils import prefill_and_generate, get_compute_capability from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, @@ -169,7 +170,7 @@ def local_chat( assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" - if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8: + if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index a9bbea6f..2c4197ab 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -20,8 +20,14 @@ import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache -from flash_attn import flash_attn_func -from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor + +try: + from flash_attn import flash_attn_func +except: + pass +from ktransformers.operators.triton_attention import decode_attention_fwd_grouped +from ktransformers.operators.triton_attention_prefill import context_attention_fwd import os from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled if flashinfer_enabled: @@ -319,18 +325,27 @@ def forward_linux_triton( key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) - value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) - attn_output = flash_attn_func( - query_states, - key_states, - value_states_padded, - softmax_scale=self.softmax_scale, - causal=True, + # for bsz = 1 + attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device) + b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device) + b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device) + + max_input_len = q_len + + context_attention_fwd( + q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim), + k=key_states.squeeze(0).view(-1, 
self.num_heads, self.q_head_dim), + v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim), + o=attn_output, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=max_input_len, + is_causal=True ) if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] + attn_output = attn_output[:, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim @@ -589,8 +604,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if os.name == 'nt' or get_compute_capability()<8: - print("for Windows or GPU before ampere, use forward_windows") + if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: return self.forward_windows( hidden_states, attention_mask, diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py index 2d8b1efa..f64e374a 100644 --- a/ktransformers/operators/dynamic_attention.py +++ b/ktransformers/operators/dynamic_attention.py @@ -17,7 +17,10 @@ logger = logging.getLogger("dynamic_attention") sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend") from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache -from flash_attn import flash_attn_func, flash_attn_with_kvcache +try: + from flash_attn import flash_attn_func, flash_attn_with_kvcache +except: + print("flash attn not found") import math diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index ce7e370e..c0dfc851 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -35,6 +35,8 @@ import cpuinfer_ext from ktransformers.operators.cpuinfer import CPUInfer from ktransformers.server.config.config import Config +from typing import Dict, Tuple, Optional, Union +import numpy as np #class KLinearBase(BaseInjectedModule, ABC): class KLinearBase(ABC): @@ -176,16 +178,182 @@ def unload(self): if self.has_bias: self.bias = None + +class KLinearQ8(KLinearBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.compute_dtype = torch.float32 + self.weight = None + self.weight_scale = None + self.weight_zero_point = None + self.bias = None + self.loaded = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + out_device = x.device + + x = x.to(device=self.device, dtype=self.compute_dtype) + + # Use the original weights for the matrix multiplication to emulate the original behavior + + # Dequantize the weights for the matrix multiplication + weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8) + out = x @ weight_dequant.T + + if self.has_bias: + out = out + self.bias + + return out.to(dtype=orig_dtype, device=out_device) + + def _dequantize_weight(self, q_matrix, scales, bits=8): + """ + Dequantize a low-precision matrix back to floating-point + + Args: + q_matrix (torch.Tensor): Quantized int matrix + scales (torch.Tensor): Scale factors for each column + bits (int): Quantization bits used (8 or 4) + + Returns: + torch.Tensor: Dequantized floating-point matrix + """ + # Ensure inputs are torch tensors + if not isinstance(q_matrix, torch.Tensor): + q_matrix = torch.tensor(q_matrix, dtype=torch.int8) + if not isinstance(scales, torch.Tensor): + scales = torch.tensor(scales,
dtype=torch.float32) + + # Convert to correct dtype if needed + if q_matrix.dtype != torch.int8: + q_matrix = q_matrix.to(torch.int8) + if scales.dtype != torch.float32: + scales = scales.to(torch.float32) + + # For Q4, ensure the values stay within 4-bit range + if bits == 4: + q_matrix = torch.clamp(q_matrix, -7, 7) + rows, cols = q_matrix.shape + dequant_matrix = q_matrix.to(torch.float32) + scales_broadcast = scales.view(1, cols) + # Apply dequantization to all columns at once using matrix multiplication + dequant_matrix = dequant_matrix * scales_broadcast + + return dequant_matrix + + + def _quantize_weight(self, matrix, bits=8): + """ + Quantize a floating-point matrix to lower precision (Q8 or Q4) + + Args: + matrix (torch.Tensor): Input matrix in floating-point format + bits (int): Quantization bits, either 8 or 4 + + Returns: + tuple: (quantized int matrix, scale factors for each column) + """ + if not isinstance(matrix, torch.Tensor): + matrix = torch.tensor(matrix, dtype=torch.float32) + + # Convert to float32 if needed + if matrix.dtype != torch.float32: + matrix = matrix.to(torch.float32) + + # Get matrix shape + rows, cols = matrix.shape + + # Determine quantization parameters based on bits + if bits == 8: + max_int = 127 + qtype = torch.int8 + elif bits == 4: + max_int = 7 + qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support + else: + raise ValueError("Quantization bits must be either 8 or 4") + + scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device) + + # Calculate max absolute value for each column + max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0) + + # Handle zero columns (avoid division by zero) + zero_cols = max_abs_vals == 0 + max_abs_vals[zero_cols] = 1.0 + + # Calculate scale factors for all columns at once + scales = max_abs_vals / max_int + + # Prepare the scales for broadcasting [1, cols] + scales_broadcast = scales.view(1, cols) + + # Apply quantization to the entire matrix at once + q_matrix = torch.round(matrix / scales_broadcast).to(qtype) + + # For Q4, clamp values to ensure they stay within 4-bit range + if bits == 4: + q_matrix = torch.clamp(q_matrix, -max_int, max_int) + + return q_matrix, scales + + def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None): + if self.loaded: return + if device is None: device = self.device + if w is None: w = self.load_weight(device=device) + + if isinstance(w, nn.Parameter): + try: + weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features) + except: + weight = w.to(dtype=self.compute_dtype) + self.has_bias = False + elif isinstance(w, tuple): + try: + weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features) + except: + weight = w[0].to(dtype=self.compute_dtype) + self.bias = w[1].to(dtype=self.compute_dtype).to(device) + self.has_bias = True + else: + raise ValueError("Invalid weight type") + + self.weight, self.weight_scale = self._quantize_weight(weight, bits=8) + + self.weight = self.weight.to(device) + self.weight_scale = self.weight_scale.to(device) + + if self.has_bias: + self.bias = self.bias.to(device) + + self.loaded = True + + def unload(self): + self.weight = None + self.weight_scale = None + self.weight_zero_point = None + self._orig_weight = None + + if self.has_bias: + self.bias = None + + self.loaded = False + + class KLinearFP8(KLinearBase): # this kernel requires special handling for weight # Please load the weight file downloaded from KVCache.AI 
- marlin_q_w: torch.Tensor - marlin_s: torch.Tensor - g_idx: torch.Tensor - sort_indices: torch.Tensor has_bias: bool weight: torch.Tensor - scale_w: torch.Tensor bias: torch.Tensor def __init__( self, @@ -468,6 +636,7 @@ def unload(self): "KLinearTorch": KLinearTorch, "KLinearCPUInfer": KLinearCPUInfer, "KLinearFP8": KLinearFP8, + "KLinearQ8": KLinearQ8, } class KTransformersLinear(BaseInjectedModule, KLinearBase): diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index 57d4bea0..bbac29a3 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -53,6 +53,7 @@ DeepseekV2DecoderLayer, DeepseekV2MoE, ) +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule @@ -649,8 +650,8 @@ def forward( if per_layer_prefill_flag: causal_mask = None else: - if os.name == 'nt' or get_compute_capability()<8: - print("for Windows or GPU before ampere, use forward_windows") + if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA: + # print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions @@ -673,6 +674,7 @@ def forward( t_f = 0 for i, decoder_layer in enumerate(self.layers): + # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py index 44375206..aafdea03 100644 --- a/ktransformers/operators/triton_attention.py +++ b/ktransformers/operators/triton_attention.py @@ -6,7 +6,7 @@ import triton import triton.language as tl - +from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor @triton.jit def tanh(x): # Tanh is just a scaled sigmoid @@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd( # [TODO] work around shmem limit on MI3xx # TODO: support hip - #if is_hip_ and Lk >= 576: - # BLOCK = 16 + if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576: + BLOCK = 16 if Lk == 576: BLOCK_DMODEL = 512 diff --git a/ktransformers/operators/triton_attention_prefill.py b/ktransformers/operators/triton_attention_prefill.py new file mode 100644 index 00000000..a807ef35 --- /dev/null +++ b/ktransformers/operators/triton_attention_prefill.py @@ -0,0 +1,206 @@ + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +# which was originally adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 + +""" +Memory-efficient attention for prefill. +It supports page size = 1. 
+""" + +# Adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 +import torch +import triton +import triton.language as tl + +is_cuda_available = torch.cuda.is_available() +if is_cuda_available: + CUDA_CAPABILITY = torch.cuda.get_device_capability() + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + Lk: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] + + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, + mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + end_n = ( + cur_batch_seq_len + if not IS_CAUSAL + else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + ) + for start_n in range(0, block_mask * end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if IS_CAUSAL: + qk += tl.where( + (start_n + offs_n[None, :] < cur_batch_seq_len) + & (offs_m[:, None] >= (start_n + offs_n[None, :])), + 0, + float("-inf"), + ) + else: + qk += tl.where( + (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf") + ) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize 
+    # initialize pointers to output
+    off_o = (
+        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+        + cur_head * stride_oh
+        + offs_d[None, :]
+    )
+    out_ptrs = Out + off_o
+    tl.store(
+        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
+    )
+
+
+def context_attention_fwd(
+    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+):
+    """
+    q, k, v: [b * s, head, head_dim]
+    b_start_loc: [b]
+    b_seq_len: [b]
+    out: [b * s, head, head_dim]
+    """
+    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
+
+    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+
+    sm_scale = 1.0 / (Lq**0.5)
+    batch, head = b_seq_len.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k.shape[1]
+
+    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+    num_warps = 4 if Lk <= 64 else 8
+
+    _fwd_kernel[grid](
+        q,
+        k,
+        v,
+        sm_scale,
+        b_start_loc,
+        b_seq_len,
+        o,
+        q.stride(0),
+        q.stride(1),
+        k.stride(0),
+        k.stride(1),
+        v.stride(0),
+        v.stride(1),
+        o.stride(0),
+        o.stride(1),
+        kv_group_num=kv_group_num,
+        BLOCK_M=BLOCK,
+        BLOCK_DMODEL=triton.next_power_of_2(Lk),
+        BLOCK_N=BLOCK,
+        IS_CAUSAL=is_causal,
+        num_warps=num_warps,
+        num_stages=1,
+        Lk=Lk,
+    )
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
index 7f3e44ea..c20973df 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
@@ -22,7 +22,7 @@
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
diff --git a/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
new file mode 100644
index 00000000..628a952e
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
@@ -0,0 +1,76 @@
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearCPUInfer"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cuda"
+      generate_op: "KLinearQ8"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/tests/test_pytorch_q8.py b/ktransformers/tests/test_pytorch_q8.py
new file mode 100644
index 00000000..a10b67fe
--- /dev/null
+++ b/ktransformers/tests/test_pytorch_q8.py
@@ -0,0 +1,46 @@
+import torch
+
+# Define a floating-point model that contains a linear layer
+class LinearModel(torch.nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        return self.linear(x)
+
+# Create the floating-point model instance
+in_features = 64
+out_features = 128
+model_fp32 = LinearModel(in_features, out_features)
+
+# Create the quantized model instance
+model_int8 = torch.ao.quantization.quantize_dynamic(
+    model_fp32,          # original floating-point model
+    {torch.nn.Linear},   # set of layer types to quantize
+    dtype=torch.qint8    # target dtype for quantization
+)
+
+# Test the model
+batch_size = 32
+input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data
+output_int8 = model_int8(input_fp32)                  # run the data through the quantized model
+
+# Print the shapes as a sanity check
+print(f"Input shape: {input_fp32.shape}")
+print(f"Output shape: {output_int8.shape}")
+
+# Compare the outputs of the original and quantized models
+with torch.no_grad():
+    output_fp32 = model_fp32(input_fp32)
+
+print(f"First few FP32 output values: {output_fp32[0, :5]}")
+print(f"First few INT8 output values: {output_int8[0, :5]}")
+
+# Compute the mean absolute error
+error = torch.abs(output_fp32 - output_int8).mean().item()
+print(f"Mean absolute error: {error}")
+
+# Print module types before and after quantization
+print(f"Module type before quantization: {type(model_fp32.linear)}")
+print(f"Module type after quantization: {type(model_int8.linear)}")
\ No newline at end of file
diff --git a/ktransformers/util/vendors.py b/ktransformers/util/vendors.py
new file mode 100644
index 00000000..c9a709e2
--- /dev/null
+++ b/ktransformers/util/vendors.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+from enum import IntEnum, auto
+from typing import Optional, Union, List
+import torch
+
+class GPUVendor(IntEnum):
+    NVIDIA = auto()
+    AMD = auto()
+    MooreThreads = auto()
+    MetaX = auto()
+    MUSA = auto()
+    Unknown = auto()
+
+class DeviceManager:
+    """
+    Device manager that provides a unified interface for handling different GPU vendors
+    """
+    def __init__(self):
+        self.gpu_vendor = self._detect_gpu_vendor()
+        self.available_devices = self._get_available_devices()
+
+    def _detect_gpu_vendor(self) -> GPUVendor:
+        """Detect GPU vendor type"""
+        if not torch.cuda.is_available():
+            # Check MUSA availability (assuming a musa module exists)
+            try:
+                import musa
+                if musa.is_available():
+                    return GPUVendor.MUSA
+            except (ImportError, AttributeError):
+                pass
+
+            return GPUVendor.Unknown
+
+        device_name = torch.cuda.get_device_name(0).lower()
+
+        if any(name in device_name for name in ["nvidia", "geforce",
"quadro", "tesla", "titan", "rtx", "gtx"]): + return GPUVendor.NVIDIA + elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]): + return GPUVendor.AMD + elif any(name in device_name for name in ["mthreads", "moore", "mtt"]): + return GPUVendor.MooreThreads + elif any(name in device_name for name in ["metax", "meta"]): + return GPUVendor.MetaX + elif "musa" in device_name: + return GPUVendor.MUSA + + # Backend check + try: + if hasattr(torch.version, 'hip') and torch.version.hip is not None: + return GPUVendor.AMD + elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None: + return GPUVendor.NVIDIA + except: + pass + + return GPUVendor.Unknown + + def _get_available_devices(self) -> List[int]: + """Get list of available device indices""" + devices = [] + + if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: + devices = list(range(torch.cuda.device_count())) + elif self.gpu_vendor == GPUVendor.MUSA: + try: + import musa + devices = list(range(musa.device_count())) + except (ImportError, AttributeError): + pass + + return devices + + def get_device_str(self, device_id: Union[int, str]) -> str: + """ + Get device string for the given device ID + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Device string representation (e.g., "cuda:0", "musa:1", "cpu") + """ + if device_id == -1 or device_id == "cpu": + return "cpu" + + if isinstance(device_id, int): + if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: + if device_id < torch.cuda.device_count(): + return f"cuda:{device_id}" + elif self.gpu_vendor == GPUVendor.MUSA: + try: + import musa + if device_id < musa.device_count(): + return f"musa:{device_id}" + except (ImportError, AttributeError): + pass + + return "cpu" + + def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device: + """ + Convert device ID to torch.device object + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + torch.device object + """ + device_str = self.get_device_str(device_id) + + # Handle MUSA device + if device_str.startswith("musa:"): + try: + import musa + index = int(device_str.split(":")[-1]) + return musa.device(index) + except (ImportError, ValueError, AttributeError): + return torch.device("cpu") + + # Standard PyTorch device + return torch.device(device_str) + + def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: + """ + Move tensor to specified device + + Args: + tensor: PyTorch tensor to move + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Tensor moved to the specified device + """ + device = self.to_torch_device(device_id) + return tensor.to(device) + + def is_available(self, index: int = 0) -> bool: + """ + Check if device at specified index is available + + Args: + index: Device index to check + + Returns: + True if the device is available, False otherwise + """ + if index < 0: + return True # CPU is always available + + return index in self.available_devices + + def get_all_devices(self) -> List[int]: + """ + Get all available device indices + + Returns: + List of available device indices (0, 1, 2, etc.) 
+ """ + return self.available_devices + +# Create global device manager instance +device_manager = DeviceManager() + +# Convenience functions +def get_device(device_id: Union[int, str] = 0) -> torch.device: + """ + Get torch.device object for the specified device ID + + Args: + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + torch.device object + """ + return device_manager.to_torch_device(device_id) + +def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: + """ + Move tensor to specified device + + Args: + tensor: PyTorch tensor to move + device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string + + Returns: + Tensor moved to the specified device + """ + return device_manager.move_tensor_to_device(tensor, device_id) + +# Get devices +cpu_device = get_device(-1) # CPU using index -1 +cpu_device2 = get_device("cpu") # CPU using string "cpu" +gpu0 = get_device(0) # First GPU + +# Move tensors +x = torch.randn(3, 3) +x_gpu = to_device(x, 0) # Move to first GPU +x_cpu1 = to_device(x, -1) # Move to CPU using index -1 +x_cpu2 = to_device(x, "cpu") # Move to CPU using string "cpu" \ No newline at end of file diff --git a/setup.py b/setup.py index ea154828..5c29b8f5 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from setuptools import setup, Extension from cpufeature.extension import CPUFeature -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME try: from torch_musa.utils.simple_porting import SimplePorting from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME @@ -64,6 +64,70 @@ def get_musa_bare_metal_version(self, musa_dir): musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}" return musa_version + def get_rocm_bare_metal_version(self, rocm_dir): + """ + Get the ROCm version from the ROCm installation directory. 
+
+        Args:
+            rocm_dir: Path to the ROCm installation directory.
+
+        Returns:
+            A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
+        """
+        try:
+            # Try using rocminfo to get version info
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/rocminfo", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT)
+            # Extract version number from output
+            match = re.search(r'(\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            # If rocminfo --version fails, try alternative methods
+            pass
+
+        try:
+            # Try reading version from release file
+            with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
+                version_str = f.read().strip()
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (FileNotFoundError, IOError):
+            pass
+
+        # If all else fails, try to extract from directory name
+        dir_name = os.path.basename(os.path.normpath(rocm_dir))
+        match = re.search(r'rocm-(\d+\.\d+)', dir_name)
+        if match:
+            version_str = match.group(1)
+            version = parse(version_str)
+            rocm_version = f"{version.major}{version.minor}"
+            return rocm_version
+
+        # Fallback to extracting from hipcc version
+        try:
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/hipcc", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT)
+            match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            pass
+
+        # If we still can't determine the version, raise an error
+        raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
+
    def get_cuda_bare_metal_version(self, cuda_dir):
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ def get_package_version(self, full_version=False):
        cpu_instruct = self.get_cpu_instruct()
        backend_version = ""
        if CUDA_HOME is not None:
-            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+            backend_version = f""
        elif MUSA_HOME is not None:
            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        elif ROCM_HOME is not None:
+            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
        else:
-            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: none of CUDA_HOME, MUSA_HOME, or ROCM_HOME is set.")
        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
        if full_version:
            return package_version
@@ -247,9 +313,13 @@ def build_extension(self, ext) -> None:
            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
        elif MUSA_HOME is not None:
            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        elif ROCM_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
        else:
            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
-
+        # log cmake_args
+        print("CMake args:", cmake_args)
+
        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [
@@ -328,7 +398,7 @@ def build_extension(self, ext) -> None:
            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
        )
-if CUDA_HOME is not None:
+if CUDA_HOME is not None or ROCM_HOME is not None:
    ops_module =
CUDAExtension('KTransformersOps', [ 'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu', 'ktransformers/ktransformers_ext/cuda/binding.cpp', @@ -338,7 +408,7 @@ def build_extension(self, ext) -> None: 'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'], 'nvcc': [ '-O3', - '--use_fast_math', + # '--use_fast_math', '-Xcompiler', '-fPIC', '-DKTRANSFORMERS_USE_CUDA', ] @@ -371,6 +441,7 @@ def build_extension(self, ext) -> None: raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.") setup( + name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=[