From e53aafb8a37cf599b53aa51d00e1e8c896534434 Mon Sep 17 00:00:00 2001
From: xiajunshi
Date: Tue, 4 Mar 2025 16:17:34 +0800
Subject: [PATCH 1/4] support for Metax GPU

---
 NOTICE                                      |  2 ++
 ktransformers/operators/triton_attention.py | 23 ++++++++++++-----
 ktransformers/util/custom_gguf.py           | 27 +++++++++++++++-----
 setup.py                                    | 28 ++++++++++++++++++---
 4 files changed, 65 insertions(+), 15 deletions(-)
 create mode 100644 NOTICE

diff --git a/NOTICE b/NOTICE
new file mode 100644
index 00000000..2110a58b
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,2 @@
+The following files may have been Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. in 2024.
+setup.py custom_gguf.py triton_attention.py
\ No newline at end of file
diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py
index 44375206..4bc792f7 100644
--- a/ktransformers/operators/triton_attention.py
+++ b/ktransformers/operators/triton_attention.py
@@ -3,9 +3,12 @@
 # which was originally adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+# 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
 
 import triton
 import triton.language as tl
+import torch
+IS_MACA_TORCH = "metax" in torch.__version__
 
 @triton.jit
 def tanh(x):
@@ -218,7 +221,11 @@ def _decode_grouped_att_m_fwd(
         "kpack": 2
     }
     """
-
+    num_warps = 4
+    num_stages = 2
+    if IS_MACA_TORCH:
+        num_warps = 2
+        num_stages = 1
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
@@ -247,8 +254,8 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         PAGE_SIZE=page_size,
         logit_cap=logit_cap,
-        num_warps=4,
-        num_stages=2,
+        num_warps=num_warps,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
@@ -336,7 +343,11 @@ def _decode_softmax_reducev_fwd(
         "kpack": 2
     }
     """
-
+    num_warps = 4
+    num_stages = 2
+    if IS_MACA_TORCH:
+        num_warps = 2
+        num_stages = 1
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
@@ -350,8 +361,8 @@ def _decode_softmax_reducev_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
-        num_warps=4,
-        num_stages=2,
+        num_warps=num_warps,
+        num_stages=num_stages,
         **extra_kargs,
     )
 
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 84ada15a..5b430253 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -11,6 +11,7 @@
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
 '''
 # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf
 # GGUF specification
@@ -28,6 +29,7 @@
 from .custom_loader import SafeTensorLoader
 import ctypes
 import math
+IS_MACA_TORCH = "metax" in torch.__version__
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -575,8 +577,13 @@ def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_defau
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
-    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
+    if IS_MACA_TORCH:
+        data = torch.tensor(data)
+        c_pointer = ctypes.addressof(ctypes.cast(data.data_ptr(), ctypes.POINTER(ctypes.c_int8)).contents)
+        return KTransformersOps.dequantize_q4_k(c_pointer, data.numel(), block_size, ele_per_blk, device, target_dtype)
+    else:
+        c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+        return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q5_k(data):
     # C implementation
@@ -698,9 +705,14 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = to
     ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
-    data = np.frombuffer(data, dtype=data.dtype)
-    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
-    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
+    if IS_MACA_TORCH:
+        data = torch.tensor(data)
+        c_pointer = ctypes.addressof(ctypes.cast(data.data_ptr(), ctypes.POINTER(ctypes.c_int8)).contents)
+        return KTransformersOps.dequantize_q6_k(c_pointer, data.numel(), block_size, ele_per_blk, device, target_dtype)
+    else:
+        data = np.frombuffer(data, dtype=data.dtype)
+        c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+        return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
 
@@ -811,7 +823,10 @@ def dequantize_f32(data):
 
 def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
-    res = torch.from_numpy(data.copy())
+    if IS_MACA_TORCH:
+        res = torch.tensor(data)
+    else:
+        res = torch.from_numpy(data.copy())
     res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu
diff --git a/setup.py b/setup.py
index ea154828..1284c499 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@
 https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
 '''
 
 import os
 
@@ -35,6 +36,7 @@
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME=None
+IS_MACA_TORCH = "metax" in torch.__version__
 
 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -319,16 +321,36 @@ def build_extension(self, ext) -> None:
         build_temp = Path(ext.sourcedir) / "build"
         if not build_temp.exists():
             build_temp.mkdir(parents=True)
+        cmake_cmd = "cmake"
+        if IS_MACA_TORCH:
+            cmake_cmd = "cmake_maca"
         result = subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
+            [cmake_cmd, ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
         )
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
         subprocess.run(
-            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
+            [cmake_cmd, "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )
 
-if CUDA_HOME is not None:
+if IS_MACA_TORCH:
+    if os.path.exists("ktransformers/ktransformers_ext/cuda/binding.cpp"):
+        os.rename("ktransformers/ktransformers_ext/cuda/binding.cpp", "ktransformers/ktransformers_ext/cuda/binding.cu")
+    ops_module = CUDAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'ktransformers/ktransformers_ext/cuda/binding.cu',
+        # TODO: More files can be added in the future
+        ],
+        extra_compile_args={
+            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
+            'cucc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+            ]
+        }
+    )
+elif CUDA_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
         'ktransformers/ktransformers_ext/cuda/binding.cpp',

From b1a5e498be35326488696f5e3b016f8954c32502 Mon Sep 17 00:00:00 2001
From: edterxj <33508160+xiajunshi@users.noreply.github.com>
Date: Thu, 6 Mar 2025 16:45:58 +0800
Subject: [PATCH 2/4] Update triton_attention.py

remove "All Rights Reserved"
---
 ktransformers/operators/triton_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ktransformers/operators/triton_attention.py b/ktransformers/operators/triton_attention.py
index 4bc792f7..f617870f 100644
--- a/ktransformers/operators/triton_attention.py
+++ b/ktransformers/operators/triton_attention.py
@@ -3,7 +3,7 @@
 # which was originally adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
-# 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
+# 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd.
 
 import triton
 import triton.language as tl

From f0912f23344b3800d4fea00237277d88d0f07e6e Mon Sep 17 00:00:00 2001
From: edterxj <33508160+xiajunshi@users.noreply.github.com>
Date: Thu, 6 Mar 2025 16:46:36 +0800
Subject: [PATCH 3/4] Update custom_gguf.py

remove "All Rights Reserved."
---
 ktransformers/util/custom_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 5b430253..2efc7379 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -11,7 +11,7 @@
 Copyright (c) 2023-2024 The ggml authors
 Copyright (c) 2024 Thomas Germer
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
+Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd.
 '''
 # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf
 # GGUF specification

From 43d4ee893c1b42986f012d8bedc3beb8db41de2c Mon Sep 17 00:00:00 2001
From: edterxj <33508160+xiajunshi@users.noreply.github.com>
Date: Thu, 6 Mar 2025 16:47:17 +0800
Subject: [PATCH 4/4] Update setup.py

remove "All Rights Reserved."
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 1284c499..f6e2b76a 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
+Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd.
 '''
 
 import os
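
All four patches hinge on the same runtime check: the MetaX (MACA) build of PyTorch is assumed to report a version string containing "metax", and every MACA-specific branch keys off that flag. The short sketch below (not part of the patches; the helper name pick_launch_params is illustrative only) shows the pattern in isolation:

    import torch

    # Assumption carried over from the patches: a MetaX/MACA build of PyTorch
    # embeds "metax" in its version string, so a membership test detects it.
    IS_MACA_TORCH = "metax" in torch.__version__

    def pick_launch_params():
        """Return (num_warps, num_stages) for the Triton decode kernels,
        mirroring the branch added in triton_attention.py."""
        if IS_MACA_TORCH:
            return 2, 1  # reduced values for the MACA backend
        return 4, 2      # CUDA defaults

    num_warps, num_stages = pick_launch_params()
    cmake_cmd = "cmake_maca" if IS_MACA_TORCH else "cmake"  # same switch as in setup.py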