diff --git a/README.md b/README.md
index 76ad6ebd..63728d21 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
diff --git a/doc/README.md b/doc/README.md
index 8bd94a00..f0683a30 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).
diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md
index d9fa9b85..854549c5 100644
--- a/doc/SUMMARY.md
+++ b/doc/SUMMARY.md
@@ -10,6 +10,7 @@
- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
- [Use FP8 GPU Kernel](en/fp8_kernel.md)
+- [Use AMD GPU](en/ROCm.md)
# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
diff --git a/doc/en/ROCm.md b/doc/en/ROCm.md
new file mode 100644
index 00000000..39f48902
--- /dev/null
+++ b/doc/en/ROCm.md
@@ -0,0 +1,96 @@
+# ROCm Support for ktransformers (Beta)
+
+## Introduction
+
+### Overview
+As part of our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been developed and tested on AMD EPYC 9274F processors and AMD Radeon RX 7900 XTX GPUs.
+
+## Installation Guide
+
+### 1. Install ROCm Driver
+Begin by installing the ROCm drivers for your AMD GPU:
+- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)
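+
+Once the driver is installed, it is worth confirming that ROCm can see the GPU before continuing. The commands below are an illustrative check using the standard ROCm CLI tools:
+
+```bash
+# List detected GPU targets; an RX 7900 XTX reports as gfx1100
+rocminfo | grep -i "gfx"
+
+# Basic device status and VRAM usage
+rocm-smi
+```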
+
+### 2. Set Up Conda Environment
+We recommend using Miniconda3/Anaconda3 for environment management:
+
+```bash
+# Download Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+
+# Create environment
+conda create --name ktransformers python=3.11
+conda activate ktransformers
+
+# Install required libraries
+conda install -c conda-forge libstdcxx-ng
+
+# Verify GLIBCXX version (should include 3.4.32)
+strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
+```
+
+> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`
+
+### 3. Install PyTorch for ROCm
+Install PyTorch with ROCm 6.2.4 support:
+
+```bash
+pip3 install torch torchvision torchaudio \
+ --index-url https://download.pytorch.org/whl/rocm6.2.4
+pip3 install packaging ninja cpufeature numpy
+```
+
+> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/)
+
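+A quick way to confirm that the ROCm build of PyTorch is active (a minimal sketch; ROCm builds report a HIP version and expose the GPU through the `torch.cuda` API):
+
+```bash
+# Should print the torch version, a non-empty HIP version, True, and the Radeon device name
+python -c "import torch; print(torch.__version__, torch.version.hip, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
+```
+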
+### 4. Build ktransformers
+
+```bash
+# Clone repository
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git submodule update --init
+
+# Optional: Compile web interface
+# See: api/server/website.md
+
+# Install dependencies
+bash install.sh
+```
+
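+If the build completes, a quick import check (illustrative) confirms that the package and its compiled `cpuinfer_ext` extension are importable:
+
+```bash
+python -c "import cpuinfer_ext, ktransformers; print(ktransformers.__version__)"
+```
+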
+## Running DeepSeek-R1 Models
+
+### Configuration for 24GB VRAM GPUs
+Use our optimized configuration for constrained VRAM:
+
+```bash
+python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-R1 \
+ --gguf_path <path_to_gguf_files> \
+ --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
+ --cpu_infer <cpu_core_count>
+```
+
+> **Beta Note:** Current Q8 linear implementation (Marlin alternative) shows suboptimal performance. Expect optimizations in future releases.
+
+### Configuration for 40GB+ VRAM GPUs
+For better performance on high-VRAM GPUs:
+
+1. Modify `DeepSeek-V3-Chat.yaml`, replacing every occurrence of `KLinearMarlin` with `KLinearTorch` (a shell one-liner for this edit is sketched after this list):
+ ```yaml
+ # Replace all instances of:
+ KLinearMarlin → KLinearTorch
+ ```
+
+2. Execute with:
+ ```bash
+ python ktransformers/local_chat.py \
+ --model_path deepseek-ai/DeepSeek-R1 \
+ --gguf_path <path_to_gguf_files> \
+ --optimize_config_path <path_to_your_modified_DeepSeek-V3-Chat.yaml> \
+ --cpu_infer <cpu_core_count>
+ ```
+> **Tip:** If you have two 24GB AMD GPUs, you can apply the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and use that file as the optimize config instead.
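+
+As referenced in step 1 above, the substitution can also be made from the shell; this is a sketch assuming GNU `sed` (back up the file first if you want to keep the original rules):
+
+```bash
+cp ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml /tmp/DeepSeek-V3-Chat.yaml.bak
+sed -i 's/KLinearMarlin/KLinearTorch/g' ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+```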
+
+## Known Limitations
+- Marlin operations not supported on ROCm platform
+- Current Q8 linear implementation shows reduced performance (Beta limitation)
diff --git a/ktransformers/__init__.py b/ktransformers/__init__.py
index b100dcb1..a8bce455 100644
--- a/ktransformers/__init__.py
+++ b/ktransformers/__init__.py
@@ -8,4 +8,4 @@
LastEditors : chenxl
LastEditTime : 2025-02-15 03:53:02
'''
-__version__ = "0.2.3.post1"
+__version__ = "0.2.3.post2"
diff --git a/ktransformers/ktransformers_ext/CMakeLists.txt b/ktransformers/ktransformers_ext/CMakeLists.txt
index 08e578db..eefcadf0 100644
--- a/ktransformers/ktransformers_ext/CMakeLists.txt
+++ b/ktransformers/ktransformers_ext/CMakeLists.txt
@@ -32,6 +32,7 @@ endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
+option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -201,6 +202,31 @@ endif()
# message(STATUS "Can't found CUDA lib")
# endif()
+if (NOT EXISTS $ENV{ROCM_PATH})
+ if (NOT EXISTS /opt/rocm)
+ set(ROCM_PATH /usr)
+ else()
+ set(ROCM_PATH /opt/rocm)
+ endif()
+else()
+ set(ROCM_PATH $ENV{ROCM_PATH})
+endif()
+
+list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
+
+if (NOT EXISTS $ENV{MUSA_PATH})
+ if (NOT EXISTS /opt/musa)
+ set(MUSA_PATH /usr/local/musa)
+ else()
+ set(MUSA_PATH /opt/musa)
+ endif()
+else()
+ set(MUSA_PATH $ENV{MUSA_PATH})
+endif()
+
+list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
+
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
@@ -218,6 +244,14 @@ elseif (UNIX)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
+ if (KTRANSFORMERS_USE_ROCM)
+ find_package(HIP REQUIRED)
+ if(HIP_FOUND)
+ include_directories("${HIP_INCLUDE_DIRS}")
+ add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
+ endif()
+ endif()
+
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
@@ -258,6 +292,11 @@ elseif(UNIX)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
+ if (KTRANSFORMERS_USE_ROCM)
+ add_compile_definitions(USE_HIP=1)
+ target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
+ message(STATUS "Building for HIP")
+ endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
diff --git a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
index d0f0c603..d0f7b11b 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -7,79 +7,83 @@
* @LastEditTime : 2024-08-07 09:47:43
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
-#ifndef CPUINFER_CPUINFER_H
-#define CPUINFER_CPUINFER_H
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#ifdef KTRANSFORMERS_USE_CUDA
-#include "vendors/cuda.h"
-#elif KTRANSFORMERS_USE_MUSA
-#include "vendors/musa.h"
-#endif
-
-#include "backend.h"
-#include "task_queue.h"
-
-#include "llama.cpp/ggml-impl.h"
-
-class CPUInfer {
- public:
- CPUInfer(int thread_num) {
- backend_ = new Backend(thread_num - 1);
- task_queue_ = new TaskQueue();
- for (int i = 0; i < (1 << 16); ++i) {
- ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
- }
- }
-
- ~CPUInfer() {
- delete backend_;
- delete task_queue_;
- }
-
- template <typename Func, typename Obj, typename... Args>
- void enqueue(Func f, Obj* obj, Args... args) {
- task_queue_->enqueue([=]() {
- std::invoke(f, *obj, args..., backend_);
- });
- }
-
- void submit(std::pair<intptr_t, intptr_t> params) {
- void (*func)(void*) = (void (*)(void*))params.first;
- void* args = (void*)params.second;
- *((CPUInfer**)args) = this;
- func(args);
- }
-
- void sync() {
- task_queue_->sync();
- }
-
- void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
- void (*func)(void*) = (void (*)(void*))params.first;
- void* args = (void*)params.second;
- *((CPUInfer**)args) = this;
- cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
- }
-
- static void sync_(void* cpu_infer_ptr) {
- CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
- cpuinfer->sync();
- }
-
- void sync_with_cuda_stream(intptr_t user_cuda_stream) {
- cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
- }
-
- public:
- Backend* backend_;
- TaskQueue* task_queue_;
-};
-
-#endif
\ No newline at end of file
+ #ifndef CPUINFER_CPUINFER_H
+ #define CPUINFER_CPUINFER_H
+
+ #include
+ #include
+ #include
+ #include
+ #include
+ #include
+ #include
+ #ifdef KTRANSFORMERS_USE_CUDA
+ #include "vendors/cuda.h"
+ #elif KTRANSFORMERS_USE_MUSA
+ #include "vendors/musa.h"
+ #elif KTRANSFORMERS_USE_ROCM
+ #define __HIP_PLATFORM_AMD__
+ #include "vendors/hip.h"
+ #endif
+
+ #include "backend.h"
+ #include "task_queue.h"
+ #include "../vendors/vendor.h"
+
+ #include "llama.cpp/ggml-impl.h"
+
+ class CPUInfer {
+ public:
+ CPUInfer(int thread_num) {
+ backend_ = new Backend(thread_num - 1);
+ task_queue_ = new TaskQueue();
+ for (int i = 0; i < (1 << 16); ++i) {
+ ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
+ }
+ }
+
+ ~CPUInfer() {
+ delete backend_;
+ delete task_queue_;
+ }
+
+ template <typename Func, typename Obj, typename... Args>
+ void enqueue(Func f, Obj* obj, Args... args) {
+ task_queue_->enqueue([=]() {
+ std::invoke(f, *obj, args..., backend_);
+ });
+ }
+
+ void submit(std::pair<intptr_t, intptr_t> params) {
+ void (*func)(void*) = (void (*)(void*))params.first;
+ void* args = (void*)params.second;
+ *((CPUInfer**)args) = this;
+ func(args);
+ }
+
+ void sync() {
+ task_queue_->sync();
+ }
+
+ void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
+ void (*func)(void*) = (void (*)(void*))params.first;
+ void* args = (void*)params.second;
+ *((CPUInfer**)args) = this;
+ cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
+ }
+
+ static void sync_(void* cpu_infer_ptr) {
+ CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
+ cpuinfer->sync();
+ }
+
+ void sync_with_cuda_stream(intptr_t user_cuda_stream) {
+ cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
+ }
+
+ public:
+ Backend* backend_;
+ TaskQueue* task_queue_;
+ };
+
+ #endif
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
index 082ad2c3..1746b073 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
@@ -1,3 +1,15 @@
#pragma once
-#include <cuda_runtime.h>
\ No newline at end of file
+#include <cuda_runtime.h>
+#include <cuda.h>
+#include <cublas_v2.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#if CUDART_VERSION < 11020
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#endif // CUDART_VERSION < 11020
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h
new file mode 100644
index 00000000..abbc1e89
--- /dev/null
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h
@@ -0,0 +1,172 @@
+#pragma once
+
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif // __HIP_PLATFORM_AMD__
+
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
+#define cublasCreate hipblasCreate
+#define cublasDestroy hipblasDestroy
+#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cublasOperation_t hipblasOperation_t
+#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchHostFunc hipLaunchHostFunc
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
+#define cudaMemGetInfo hipMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStreamFireAndForget hipStreamFireAndForget
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#define cudaHostFn_t hipHostFn_t
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
+#define GCN
+#endif
+
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
+#define CDNA
+#endif
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+ defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif
+
+#ifndef __has_builtin
+ #define __has_builtin(x) 0
+#endif
+
+typedef hip_bfloat16 nv_bfloat16;
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
index 18922218..6cc1b69e 100644
--- a/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
@@ -1,9 +1,137 @@
#pragma once
#include
+#include
+#include
#include
-
+#include
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_32F MUSA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
+#define cublasGetStatusString mublasStatus_to_string
+#define cudaDataType_t musaDataType_t
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
-#define cudaHostFn_t musaHostFn_t
-#define nv_bfloat16 mt_bfloat16
\ No newline at end of file
+#define cudaSuccess musaSuccess
+
+// Additional mappings for MUSA virtual memory pool
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// Additional mappings for MUSA graphs
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;
diff --git a/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h b/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h
new file mode 100644
index 00000000..84704389
--- /dev/null
+++ b/ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h
@@ -0,0 +1,13 @@
+#ifndef CPUINFER_VENDOR_VENDOR_H
+#define CPUINFER_VENDOR_VENDOR_H
+
+#ifdef USE_CUDA
+#include "cuda.h"
+#elif USE_HIP
+#define __HIP_PLATFORM_AMD__
+#include "hip.h"
+#elif USE_MUSA
+#include "musa.h"
+#endif
+
+#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
index e80efc45..23630356 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
@@ -16,6 +16,10 @@
#include
#include
+#ifdef KTRANSFORMERS_USE_ROCM
+typedef hip_bfloat16 nv_bfloat16;
+#endif
+
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
index 54e538ae..87f4581b 100644
--- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
+++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu
@@ -36,7 +36,7 @@ inline std::string str(T x) {
namespace gptq_marlin {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh
index 66a59203..ccf9cfd8 100644
--- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh
+++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh
@@ -39,7 +39,7 @@ using I4 = Vec;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
// No support for async
#else
diff --git a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh
index b8babfb0..80f6ea43 100644
--- a/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh
+++ b/ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh
@@ -8,6 +8,11 @@
#include
#include
+#ifdef __HIP_PLATFORM_AMD__
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+#endif
+
namespace gptq_marlin {
template
diff --git a/ktransformers/ktransformers_ext/ext_bindings.cpp b/ktransformers/ktransformers_ext/ext_bindings.cpp
index 902d4271..0078a79e 100644
--- a/ktransformers/ktransformers_ext/ext_bindings.cpp
+++ b/ktransformers/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,6 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
-#include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
diff --git a/ktransformers/ktransformers_ext/vendors/cuda.h b/ktransformers/ktransformers_ext/vendors/cuda.h
new file mode 100644
index 00000000..1746b073
--- /dev/null
+++ b/ktransformers/ktransformers_ext/vendors/cuda.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda.h>
+#include <cublas_v2.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#if CUDART_VERSION < 11020
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#endif // CUDART_VERSION < 11020
diff --git a/ktransformers/ktransformers_ext/vendors/hip.h b/ktransformers/ktransformers_ext/vendors/hip.h
new file mode 100644
index 00000000..abbc1e89
--- /dev/null
+++ b/ktransformers/ktransformers_ext/vendors/hip.h
@@ -0,0 +1,172 @@
+#pragma once
+
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif // __HIP_PLATFORM_AMD__
+
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
+#define cublasCreate hipblasCreate
+#define cublasDestroy hipblasDestroy
+#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cublasOperation_t hipblasOperation_t
+#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchHostFunc hipLaunchHostFunc
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
+#define cudaMemGetInfo hipMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStreamFireAndForget hipStreamFireAndForget
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#define cudaHostFn_t hipHostFn_t
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
+#define GCN
+#endif
+
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
+#define CDNA
+#endif
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+ defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif
+
+#ifndef __has_builtin
+ #define __has_builtin(x) 0
+#endif
+
+typedef hip_bfloat16 nv_bfloat16;
diff --git a/ktransformers/ktransformers_ext/vendors/musa.h b/ktransformers/ktransformers_ext/vendors/musa.h
new file mode 100644
index 00000000..6cc1b69e
--- /dev/null
+++ b/ktransformers/ktransformers_ext/vendors/musa.h
@@ -0,0 +1,137 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_32F MUSA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
+#define cublasGetStatusString mublasStatus_to_string
+#define cudaDataType_t musaDataType_t
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
+#define cudaStream_t musaStream_t
+#define cudaSuccess musaSuccess
+
+// Additional mappings for MUSA virtual memory pool
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// Additional mappings for MUSA graphs
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;
diff --git a/ktransformers/ktransformers_ext/vendors/vendor.h b/ktransformers/ktransformers_ext/vendors/vendor.h
new file mode 100644
index 00000000..84704389
--- /dev/null
+++ b/ktransformers/ktransformers_ext/vendors/vendor.h
@@ -0,0 +1,13 @@
+#ifndef CPUINFER_VENDOR_VENDOR_H
+#define CPUINFER_VENDOR_VENDOR_H
+
+#ifdef USE_CUDA
+#include "cuda.h"
+#elif USE_HIP
+#define __HIP_PLATFORM_AMD__
+#include "hip.h"
+#elif USE_MUSA
+#include "musa.h"
+#endif
+
+#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index dc5747fa..1a5f55f9 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -31,6 +31,7 @@
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -169,7 +170,7 @@ def local_chat(
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
- if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+ if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index a9bbea6f..2c4197ab 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -20,8 +20,14 @@
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+
+try:
+ from flash_attn import flash_attn_func
+except:
+ pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
@@ -319,18 +325,27 @@ def forward_linux_triton(
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
- value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states_padded,
- softmax_scale=self.softmax_scale,
- causal=True,
+ # for bsz = 1
+ attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+ b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+ b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+
+ max_input_len = q_len
+
+ context_attention_fwd(
+ q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+ k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+ v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+ o=attn_output,
+ b_start_loc=b_start_loc,
+ b_seq_len=b_seq_len,
+ max_input_len=max_input_len,
+ is_causal=True
)
if self.q_head_dim != self.v_head_dim:
- attn_output = attn_output[:, :, :, : self.v_head_dim]
+ attn_output = attn_output[:, :, : self.v_head_dim]
attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.v_head_dim
@@ -589,8 +604,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if os.name == 'nt' or get_compute_capability()<8:
- print("for Windows or GPU before ampere, use forward_windows")
+ if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
return self.forward_windows(
hidden_states,
attention_mask,
diff --git a/ktransformers/operators/dynamic_attention.py b/ktransformers/operators/dynamic_attention.py
index 2d8b1efa..f64e374a 100644
--- a/ktransformers/operators/dynamic_attention.py
+++ b/ktransformers/operators/dynamic_attention.py
@@ -17,7 +17,10 @@
logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+ from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+ print("falsh attn not found")
import math
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index ce7e370e..c0dfc851 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -35,6 +35,8 @@
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np
#class KLinearBase(BaseInjectedModule, ABC):
class KLinearBase(ABC):
@@ -176,16 +178,182 @@ def unload(self):
if self.has_bias:
self.bias = None
+
+class KLinearQ8(KLinearBase):
+ def __init__(
+ self,
+ key: str,
+ gguf_loader: GGUFLoader,
+ config: PretrainedConfig,
+ orig_module: nn.Module = None,
+ device: str = "cuda",
+ **kwargs,
+ ):
+ super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+ self.has_bias = False
+ self.compute_dtype = torch.float32
+ self.weight = None
+ self.weight_scale = None
+ self.weight_zero_point = None
+ self.bias = None
+ self.loaded = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ orig_dtype = x.dtype
+ out_device = x.device
+
+ x = x.to(device=self.device, dtype=self.compute_dtype)
+
+ # Emulate the original layer by doing the matmul with the reconstructed original weights
+
+ # Dequantize the weights, then perform the matrix multiplication
+ weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
+ out = x @ weight_dequant.T
+
+ if self.has_bias:
+ out = out + self.bias
+
+ return out.to(dtype=orig_dtype, device=out_device)
+
+ def _dequantize_weight(self, q_matrix, scales, bits=8):
+ """
+ Dequantize a low-precision matrix back to floating-point
+
+ Args:
+ q_matrix (torch.Tensor): Quantized int matrix
+ scales (torch.Tensor): Scale factors for each column
+ bits (int): Quantization bits used (8 or 4)
+
+ Returns:
+ torch.Tensor: Dequantized floating-point matrix
+ """
+ # Ensure inputs are torch tensors
+ if not isinstance(q_matrix, torch.Tensor):
+ q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
+ if not isinstance(scales, torch.Tensor):
+ scales = torch.tensor(scales, dtype=torch.float32)
+
+ # Convert to correct dtype if needed
+ if q_matrix.dtype != torch.int8:
+ q_matrix = q_matrix.to(torch.int8)
+ if scales.dtype != torch.float32:
+ scales = scales.to(torch.float32)
+
+ # For Q4, ensure the values stay within 4-bit range
+ if bits == 4:
+ q_matrix = torch.clamp(q_matrix, -7, 7)
+ rows, cols = q_matrix.shape
+ dequant_matrix = q_matrix.to(torch.float32)
+ scales_broadcast = scales.view(1, cols)
+ # Apply dequantization to all columns at once using matrix multiplication
+ dequant_matrix = dequant_matrix * scales_broadcast
+
+ return dequant_matrix
+
+
+ def _quantize_weight(self, matrix, bits=8):
+ """
+ Quantize a floating-point matrix to lower precision (Q8 or Q4)
+
+ Args:
+ matrix (torch.Tensor): Input matrix in floating-point format
+ bits (int): Quantization bits, either 8 or 4
+
+ Returns:
+ tuple: (quantized int matrix, scale factors for each column)
+ """
+ if not isinstance(matrix, torch.Tensor):
+ matrix = torch.tensor(matrix, dtype=torch.float32)
+
+ # Convert to float32 if needed
+ if matrix.dtype != torch.float32:
+ matrix = matrix.to(torch.float32)
+
+ # Get matrix shape
+ rows, cols = matrix.shape
+
+ # Determine quantization parameters based on bits
+ if bits == 8:
+ max_int = 127
+ qtype = torch.int8
+ elif bits == 4:
+ max_int = 7
+ qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support
+ else:
+ raise ValueError("Quantization bits must be either 8 or 4")
+
+ scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
+
+ # Calculate max absolute value for each column
+ max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
+
+ # Handle zero columns (avoid division by zero)
+ zero_cols = max_abs_vals == 0
+ max_abs_vals[zero_cols] = 1.0
+
+ # Calculate scale factors for all columns at once
+ scales = max_abs_vals / max_int
+
+ # Prepare the scales for broadcasting [1, cols]
+ scales_broadcast = scales.view(1, cols)
+
+ # Apply quantization to the entire matrix at once
+ q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
+
+ # For Q4, clamp values to ensure they stay within 4-bit range
+ if bits == 4:
+ q_matrix = torch.clamp(q_matrix, -max_int, max_int)
+
+ return q_matrix, scales
+
+ def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
+ if self.loaded: return
+ if device is None: device = self.device
+ if w is None: w = self.load_weight(device=device)
+
+ if isinstance(w, nn.Parameter):
+ try:
+ weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+ except:
+ weight = w.to(dtype=self.compute_dtype)
+ self.has_bias = False
+ elif isinstance(w, tuple):
+ try:
+ weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+ except:
+ weight = w[0].to(dtype=self.compute_dtype)
+ self.bias = w[1].to(dtype=self.compute_dtype).to(device)
+ self.has_bias = True
+ else:
+ raise ValueError("Invalid weight type")
+
+ self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
+
+ self.weight = self.weight.to(device)
+ self.weight_scale = self.weight_scale.to(device)
+
+ if self.has_bias:
+ self.bias = self.bias.to(device)
+
+ self.loaded = True
+
+ def unload(self):
+ self.weight = None
+ self.weight_scale = None
+ self.weight_zero_point = None
+ self._orig_weight = None
+
+ if self.has_bias:
+ self.bias = None
+
+ self.loaded = False
+
+
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
- marlin_q_w: torch.Tensor
- marlin_s: torch.Tensor
- g_idx: torch.Tensor
- sort_indices: torch.Tensor
has_bias: bool
weight: torch.Tensor
- scale_w: torch.Tensor
bias: torch.Tensor
def __init__(
self,
@@ -468,6 +636,7 @@ def unload(self):
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
+ "KLinearQ8": KLinearQ8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index 57d4bea0..bbac29a3 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -53,6 +53,7 @@
DeepseekV2DecoderLayer,
DeepseekV2MoE,
)
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
@@ -649,8 +650,8 @@ def forward(
if per_layer_prefill_flag:
causal_mask = None
else:
- if os.name == 'nt' or get_compute_capability()<8:
- print("for Windows or GPU before ampere, use forward_windows")
+ if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+ # print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
@@ -673,6 +674,7 @@ def forward(
t_f = 0
for i, decoder_layer in enumerate(self.layers):
+ # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py
index 44375206..aafdea03 100644
--- a/ktransformers/operators/triton_attention.py
+++ b/ktransformers/operators/triton_attention.py
@@ -6,7 +6,7 @@
import triton
import triton.language as tl
-
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
# [TODO] work around shmem limit on MI3xx
# TODO: support hip
- #if is_hip_ and Lk >= 576:
- # BLOCK = 16
+ if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+ BLOCK = 16
if Lk == 576:
BLOCK_DMODEL = 512
diff --git a/ktransformers/operators/triton_attention_prefill.py b/ktransformers/operators/triton_attention_prefill.py
new file mode 100644
index 00000000..a807ef35
--- /dev/null
+++ b/ktransformers/operators/triton_attention_prefill.py
@@ -0,0 +1,206 @@
+
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
+# which was originally adapted from
+# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
+
+"""
+Memory-efficient attention for prefill.
+It supports page size = 1.
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+ CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+
+@triton.jit
+def _fwd_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ B_Start_Loc,
+ B_Seqlen,
+ Out,
+ stride_qbs,
+ stride_qh,
+ stride_kbs,
+ stride_kh,
+ stride_vbs,
+ stride_vh,
+ stride_obs,
+ stride_oh,
+ kv_group_num: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ IS_CAUSAL: tl.constexpr,
+ Lk: tl.constexpr,
+):
+ cur_batch = tl.program_id(0)
+ cur_head = tl.program_id(1)
+ start_m = tl.program_id(2)
+
+ cur_kv_head = cur_head // kv_group_num
+
+ cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+ block_start_loc = BLOCK_M * start_m
+
+ # initialize offsets
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_q = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ + cur_head * stride_qh
+ + offs_d[None, :]
+ )
+ off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
+ off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
+
+ mask_d = offs_d < Lk
+
+ q = tl.load(
+ Q + off_q,
+ mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+ other=0.0,
+ )
+
+ k_ptrs = K + off_k
+ v_ptrs = V + off_v
+
+ # initialize pointer to m and l
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+ end_n = (
+ cur_batch_seq_len
+ if not IS_CAUSAL
+ else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+ )
+ for start_n in range(0, block_mask * end_n, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ # -- compute qk ----
+ k = tl.load(
+ k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
+ mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
+ other=0.0,
+ )
+ # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
+
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+ qk *= sm_scale
+
+ if IS_CAUSAL:
+ qk += tl.where(
+ (start_n + offs_n[None, :] < cur_batch_seq_len)
+ & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+ 0,
+ float("-inf"),
+ )
+ else:
+ qk += tl.where(
+ (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+ )
+
+ # -- compute m_ij, p, l_ij
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(
+ v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+ mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
+ other=0.0,
+ )
+
+ p = p.to(v.dtype)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+ # initialize pointers to output
+ off_o = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ + cur_head * stride_oh
+ + offs_d[None, :]
+ )
+ out_ptrs = Out + off_o
+ tl.store(
+ out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
+ )
+
+
+def context_attention_fwd(
+ q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+):
+ """
+ q, k, v: [b * s, head, head_dim]
+ b_start_loc: [b]
+ b_seq_len: [b]
+ out: [b * s, head, head_dim]
+ """
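+    # Illustrative note on the assumed packed layout: with two sequences of
+    # lengths 3 and 5, b_seq_len = [3, 5] and b_start_loc = [0, 3], while q, k, v
+    # are packed along the first dimension as [3 + 5, head, head_dim];
+    # max_input_len (here 5) sets the grid size along the sequence axis below.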
+ if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+ BLOCK = 128
+ else:
+ BLOCK = 64
+
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+
+ sm_scale = 1.0 / (Lq**0.5)
+ batch, head = b_seq_len.shape[0], q.shape[1]
+ kv_group_num = q.shape[1] // k.shape[1]
+
+ grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+ num_warps = 4 if Lk <= 64 else 8
+
+ _fwd_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ b_start_loc,
+ b_seq_len,
+ o,
+ q.stride(0),
+ q.stride(1),
+ k.stride(0),
+ k.stride(1),
+ v.stride(0),
+ v.stride(1),
+ o.stride(0),
+ o.stride(1),
+ kv_group_num=kv_group_num,
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=triton.next_power_of_2(Lk),
+ BLOCK_N=BLOCK,
+ IS_CAUSAL=is_causal,
+ num_warps=num_warps,
+ num_stages=1,
+ Lk=Lk,
+ )
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
index 7f3e44ea..c20973df 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
@@ -22,7 +22,7 @@
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
- generate_device: "cuda"
+ generate_device: "cpu"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
diff --git a/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
new file mode 100644
index 00000000..628a952e
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
@@ -0,0 +1,76 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "KLinearCPUInfer"
+ prefill_op: "KLinearTorch"
+
+- match:
+ name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cuda"
+ generate_op: "KLinearQ8"
+ prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+ class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+ kwargs:
+ prefill_device: "cuda"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower)
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file
diff --git a/ktransformers/tests/test_pytorch_q8.py b/ktransformers/tests/test_pytorch_q8.py
new file mode 100644
index 00000000..a10b67fe
--- /dev/null
+++ b/ktransformers/tests/test_pytorch_q8.py
@@ -0,0 +1,46 @@
+import torch
+
+# Define a floating-point model containing a single linear layer
+class LinearModel(torch.nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        return self.linear(x)
+
+# Create the floating-point model instance
+in_features = 64
+out_features = 128
+model_fp32 = LinearModel(in_features, out_features)
+
+# Create the dynamically quantized model instance
+model_int8 = torch.ao.quantization.quantize_dynamic(
+    model_fp32,          # the original floating-point model
+    {torch.nn.Linear},   # the set of layer types to quantize
+    dtype=torch.qint8    # the target quantized data type
+)
+
+# Run the model
+batch_size = 32
+input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data
+output_int8 = model_int8(input_fp32)  # run the input through the quantized model
+
+# Print shapes to verify
+print(f"Input shape: {input_fp32.shape}")
+print(f"Output shape: {output_int8.shape}")
+
+# Compare the outputs of the original and quantized models
+with torch.no_grad():
+    output_fp32 = model_fp32(input_fp32)
+
+print(f"First few FP32 output values: {output_fp32[0, 0, :5]}")
+print(f"First few INT8 output values: {output_int8[0, 0, :5]}")
+
+# Compute the mean absolute error
+error = torch.abs(output_fp32 - output_int8).mean().item()
+print(f"Mean absolute error: {error}")
+
+# Print module types before and after quantization
+print(f"Linear module type before quantization: {type(model_fp32.linear)}")
+print(f"Linear module type after quantization: {type(model_int8.linear)}")
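+
+# Optional check (illustrative; assumes the dynamically quantized Linear exposes
+# a weight() accessor returning the stored qint8 tensor):
+quantized_weight = model_int8.linear.weight()
+print(f"Quantized weight dtype: {quantized_weight.dtype}")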
\ No newline at end of file
diff --git a/ktransformers/util/vendors.py b/ktransformers/util/vendors.py
new file mode 100644
index 00000000..c9a709e2
--- /dev/null
+++ b/ktransformers/util/vendors.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+from enum import IntEnum, auto
+from typing import Optional, Union, List
+import torch
+
+class GPUVendor(IntEnum):
+ NVIDIA = auto()
+ AMD = auto()
+ MooreThreads = auto()
+ MetaX = auto()
+ MUSA = auto()
+ Unknown = auto()
+
+class DeviceManager:
+ """
+ Device manager that provides a unified interface for handling different GPU vendors
+ """
+ def __init__(self):
+ self.gpu_vendor = self._detect_gpu_vendor()
+ self.available_devices = self._get_available_devices()
+
+ def _detect_gpu_vendor(self) -> GPUVendor:
+ """Detect GPU vendor type"""
+ if not torch.cuda.is_available():
+ # Check MUSA availability (assuming a musa module exists)
+ try:
+ import musa
+ if musa.is_available():
+ return GPUVendor.MUSA
+ except (ImportError, AttributeError):
+ pass
+
+ return GPUVendor.Unknown
+
+ device_name = torch.cuda.get_device_name(0).lower()
+
+ if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
+ return GPUVendor.NVIDIA
+ elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
+ return GPUVendor.AMD
+ elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
+ return GPUVendor.MooreThreads
+ elif any(name in device_name for name in ["metax", "meta"]):
+ return GPUVendor.MetaX
+ elif "musa" in device_name:
+ return GPUVendor.MUSA
+
+ # Backend check
+ try:
+ if hasattr(torch.version, 'hip') and torch.version.hip is not None:
+ return GPUVendor.AMD
+ elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
+ return GPUVendor.NVIDIA
+        except Exception:
+ pass
+
+ return GPUVendor.Unknown
+
+ def _get_available_devices(self) -> List[int]:
+ """Get list of available device indices"""
+ devices = []
+
+ if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
+ devices = list(range(torch.cuda.device_count()))
+ elif self.gpu_vendor == GPUVendor.MUSA:
+ try:
+ import musa
+ devices = list(range(musa.device_count()))
+ except (ImportError, AttributeError):
+ pass
+
+ return devices
+
+ def get_device_str(self, device_id: Union[int, str]) -> str:
+ """
+ Get device string for the given device ID
+
+ Args:
+ device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
+
+ Returns:
+ Device string representation (e.g., "cuda:0", "musa:1", "cpu")
+ """
+ if device_id == -1 or device_id == "cpu":
+ return "cpu"
+
+ if isinstance(device_id, int):
+ if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
+ if device_id < torch.cuda.device_count():
+ return f"cuda:{device_id}"
+ elif self.gpu_vendor == GPUVendor.MUSA:
+ try:
+ import musa
+ if device_id < musa.device_count():
+ return f"musa:{device_id}"
+ except (ImportError, AttributeError):
+ pass
+
+ return "cpu"
+
+ def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
+ """
+ Convert device ID to torch.device object
+
+ Args:
+ device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
+
+ Returns:
+ torch.device object
+ """
+ device_str = self.get_device_str(device_id)
+
+ # Handle MUSA device
+ if device_str.startswith("musa:"):
+ try:
+ import musa
+ index = int(device_str.split(":")[-1])
+ return musa.device(index)
+ except (ImportError, ValueError, AttributeError):
+ return torch.device("cpu")
+
+ # Standard PyTorch device
+ return torch.device(device_str)
+
+ def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
+ """
+ Move tensor to specified device
+
+ Args:
+ tensor: PyTorch tensor to move
+ device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
+
+ Returns:
+ Tensor moved to the specified device
+ """
+ device = self.to_torch_device(device_id)
+ return tensor.to(device)
+
+ def is_available(self, index: int = 0) -> bool:
+ """
+ Check if device at specified index is available
+
+ Args:
+ index: Device index to check
+
+ Returns:
+ True if the device is available, False otherwise
+ """
+ if index < 0:
+ return True # CPU is always available
+
+ return index in self.available_devices
+
+ def get_all_devices(self) -> List[int]:
+ """
+ Get all available device indices
+
+ Returns:
+ List of available device indices (0, 1, 2, etc.)
+ """
+ return self.available_devices
+
+# Create global device manager instance
+device_manager = DeviceManager()
+
+# Convenience functions
+def get_device(device_id: Union[int, str] = 0) -> torch.device:
+ """
+ Get torch.device object for the specified device ID
+
+ Args:
+ device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
+
+ Returns:
+ torch.device object
+ """
+ return device_manager.to_torch_device(device_id)
+
+def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
+ """
+ Move tensor to specified device
+
+ Args:
+ tensor: PyTorch tensor to move
+ device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
+
+ Returns:
+ Tensor moved to the specified device
+ """
+ return device_manager.move_tensor_to_device(tensor, device_id)
+
+# Example usage, guarded so that importing this module has no side effects
+if __name__ == "__main__":
+    # Get devices
+    cpu_device = get_device(-1)      # CPU using index -1
+    cpu_device2 = get_device("cpu")  # CPU using string "cpu"
+    gpu0 = get_device(0)             # First GPU
+
+    # Move tensors
+    x = torch.randn(3, 3)
+    x_gpu = to_device(x, 0)       # Move to first GPU
+    x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
+    x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
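+    # Also report the detected vendor (illustrative; reads the GPUVendor value set by DeviceManager)
+    print(f"Detected GPU vendor: {device_manager.gpu_vendor.name}")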
\ No newline at end of file
diff --git a/setup.py b/setup.py
index ea154828..5c29b8f5 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
try:
from torch_musa.utils.simple_porting import SimplePorting
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
@@ -64,6 +64,70 @@ def get_musa_bare_metal_version(self, musa_dir):
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return musa_version
+ def get_rocm_bare_metal_version(self, rocm_dir):
+ """
+ Get the ROCm version from the ROCm installation directory.
+
+ Args:
+ rocm_dir: Path to the ROCm installation directory
+
+ Returns:
+ A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
+ """
+ try:
+            # Try rocminfo to get version info
+ raw_output = subprocess.check_output(
+ [rocm_dir + "/bin/rocminfo", "--version"],
+ universal_newlines=True,
+ stderr=subprocess.STDOUT)
+ # Extract version number from output
+ match = re.search(r'(\d+\.\d+)', raw_output)
+ if match:
+ version_str = match.group(1)
+ version = parse(version_str)
+ rocm_version = f"{version.major}{version.minor}"
+ return rocm_version
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # If rocminfo --version fails, try alternative methods
+ pass
+
+ try:
+ # Try reading version from release file
+ with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
+ version_str = f.read().strip()
+ version = parse(version_str)
+ rocm_version = f"{version.major}{version.minor}"
+ return rocm_version
+ except (FileNotFoundError, IOError):
+ pass
+
+ # If all else fails, try to extract from directory name
+ dir_name = os.path.basename(os.path.normpath(rocm_dir))
+ match = re.search(r'rocm-(\d+\.\d+)', dir_name)
+ if match:
+ version_str = match.group(1)
+ version = parse(version_str)
+ rocm_version = f"{version.major}{version.minor}"
+ return rocm_version
+
+ # Fallback to extracting from hipcc version
+ try:
+ raw_output = subprocess.check_output(
+ [rocm_dir + "/bin/hipcc", "--version"],
+ universal_newlines=True,
+ stderr=subprocess.STDOUT)
+ match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
+ if match:
+ version_str = match.group(1)
+ version = parse(version_str)
+ rocm_version = f"{version.major}{version.minor}"
+ return rocm_version
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ pass
+
+ # If we still can't determine the version, raise an error
+ raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
+
def get_cuda_bare_metal_version(self, cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ def get_package_version(self, full_version=False):
cpu_instruct = self.get_cpu_instruct()
backend_version = ""
if CUDA_HOME is not None:
- backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+ backend_version = f""
elif MUSA_HOME is not None:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+ elif ROCM_HOME is not None:
+ backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
else:
- raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: none of CUDA_HOME, MUSA_HOME, or ROCM_HOME is set.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
if full_version:
return package_version
@@ -247,9 +313,13 @@ def build_extension(self, ext) -> None:
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
elif MUSA_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+ elif ROCM_HOME is not None:
+ cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
-
+ # log cmake_args
+ print("CMake args:", cmake_args)
+
build_args = []
if "CMAKE_ARGS" in os.environ:
cmake_args += [
@@ -328,7 +398,7 @@ def build_extension(self, ext) -> None:
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
)
-if CUDA_HOME is not None:
+if CUDA_HOME is not None or ROCM_HOME is not None:
ops_module = CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
@@ -338,7 +408,7 @@ def build_extension(self, ext) -> None:
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
'nvcc': [
'-O3',
- '--use_fast_math',
+ # '--use_fast_math',
'-Xcompiler', '-fPIC',
'-DKTRANSFORMERS_USE_CUDA',
]
@@ -371,6 +441,7 @@ def build_extension(self, ext) -> None:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
setup(
+ name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=[