
support for Metax GPU #819

Open · wants to merge 4 commits into base: main

2 changes: 2 additions & 0 deletions NOTICE
@@ -0,0 +1,2 @@
The following files may have been modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. in 2024.
setup.py custom_gguf.py triton_attention.py
22 changes: 16 additions & 6 deletions ktransformers/operators/triton_attention.py
@@ -3,9 +3,11 @@
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# 2025 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd.

import torch  # needed here for the MACA version check below
import triton
import triton.language as tl

# Detect MetaX MACA builds of PyTorch via the version string
IS_MACA_TORCH = "metax" in torch.__version__

@triton.jit
def tanh(x):
@@ -218,7 +220,11 @@ def _decode_grouped_att_m_fwd(
"kpack": 2
}
"""

num_warps = 4
num_stages = 2
if IS_MACA_TORCH:
    num_warps = 2
    num_stages = 1
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
Expand Down Expand Up @@ -247,8 +253,8 @@ def _decode_grouped_att_m_fwd(
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
num_warps=num_warps,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
@@ -336,7 +342,11 @@ def _decode_softmax_reducev_fwd(
"kpack": 2
}
"""

num_warps = 4
num_stages = 2
if IS_MACA_TORCH:
    num_warps = 2
    num_stages = 1
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
@@ -350,8 +360,8 @@ def _decode_softmax_reducev_fwd(
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
num_stages=2,
num_warps=num_warps,
num_stages=num_stages,
**extra_kargs,
)

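Taken together, the launch-parameter change applied in both decode kernels reduces to the pattern below. This is a minimal sketch, not code from the PR itself: the helper name select_launch_params is made up for illustration, while the values mirror the diff (CUDA keeps the original 4 warps / 2 stages, MACA builds drop to 2 warps / 1 stage).

import torch

# MetaX's MACA builds of PyTorch report "metax" in the version string.
IS_MACA_TORCH = "metax" in torch.__version__

def select_launch_params():
    """Pick Triton launch parameters for the decode kernels (illustrative helper, not from the PR)."""
    num_warps, num_stages = 4, 2      # original CUDA defaults
    if IS_MACA_TORCH:
        num_warps, num_stages = 2, 1  # reduced values the PR uses on MACA
    return num_warps, num_stages

Both _decode_grouped_att_m_fwd and _decode_softmax_reducev_fwd then forward these values to their Triton kernels in place of the previously hard-coded num_warps=4, num_stages=2.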
27 changes: 21 additions & 6 deletions ktransformers/util/custom_gguf.py
@@ -11,6 +11,7 @@
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd.
'''
# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf
# GGUF specification
@@ -28,6 +29,7 @@
from .custom_loader import SafeTensorLoader
import ctypes
import math
IS_MACA_TORCH = "metax" in torch.__version__

class GGMLQuantizationType(IntEnum):
F32 = 0
@@ -575,8 +577,13 @@ def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_defau
device = torch.device(device)
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
if IS_MACA_TORCH:
    data = torch.tensor(data)
    c_pointer = ctypes.addressof(ctypes.cast(data.data_ptr(), ctypes.POINTER(ctypes.c_int8)).contents)
    return KTransformersOps.dequantize_q4_k(c_pointer, data.numel(), block_size, ele_per_blk, device, target_dtype)
else:
    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

def dequantize_q5_k(data):
# C implementation
@@ -698,9 +705,14 @@ def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = to
ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
device = torch.device(device)
num_blocks = len(data) // block_size
data = np.frombuffer(data, dtype=data.dtype)
c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
if IS_MACA_TORCH:
    data = torch.tensor(data)
    c_pointer = ctypes.addressof(ctypes.cast(data.data_ptr(), ctypes.POINTER(ctypes.c_int8)).contents)
    return KTransformersOps.dequantize_q6_k(c_pointer, data.numel(), block_size, ele_per_blk, device, target_dtype)
else:
    data = np.frombuffer(data, dtype=data.dtype)
    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)

kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)

@@ -811,7 +823,10 @@ def dequantize_f32(data):

def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
data = np.frombuffer(data, dtype=np.float32)
res = torch.from_numpy(data.copy())
if IS_MACA_TORCH:
    res = torch.tensor(data)
else:
    res = torch.from_numpy(data.copy())
res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
res_gpu.copy_(res)
return res_gpu
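The two GGUF dequantization changes (Q4_K and Q6_K) share one idea: on a MACA torch the raw buffer is first copied into a torch tensor and its data_ptr() is handed to KTransformersOps, whereas the CUDA path keeps passing the numpy buffer's ctypes address. A hedged sketch of that hand-off; buffer_pointer_and_count is an illustrative helper, not a function in the PR:

import ctypes
import numpy as np
import torch

IS_MACA_TORCH = "metax" in torch.__version__

def buffer_pointer_and_count(data: np.ndarray):
    """Return (address, element count) for a dequantize op, mirroring the PR's two code paths."""
    if IS_MACA_TORCH:
        tensor = torch.tensor(data)  # copy the bytes into a torch tensor first
        ptr = ctypes.addressof(ctypes.cast(tensor.data_ptr(), ctypes.POINTER(ctypes.c_int8)).contents)
        return ptr, tensor.numel()
    ptr = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
    return ptr, data.size

dequantize_f32_gpu follows the same split in a simpler form: torch.tensor(data) on MACA versus torch.from_numpy(data.copy()) elsewhere.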
28 changes: 25 additions & 3 deletions setup.py
@@ -11,6 +11,7 @@
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
Copyright (c) 2025 by MetaX Integrated Circuits (Shanghai) Co., Ltd.
'''

import os
@@ -35,6 +36,7 @@
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
IS_MACA_TORCH = "metax" in torch.__version__

class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -319,16 +321,36 @@ def build_extension(self, ext) -> None:
build_temp = Path(ext.sourcedir) / "build"
if not build_temp.exists():
build_temp.mkdir(parents=True)
cmake_cmd = "cmake"
if IS_MACA_TORCH:
    cmake_cmd = "cmake_maca"  # use the MACA CMake wrapper on MetaX builds
result = subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
[cmake_cmd, ext.sourcedir, *cmake_args], cwd=build_temp, check=True, capture_output=True
)
print("Standard output:", result.stdout)
print("Standard error:", result.stderr)
subprocess.run(
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
[camke_cmd, "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
)

if CUDA_HOME is not None:
if IS_MACA_TORCH:
    if os.path.exists("ktransformers/ktransformers_ext/cuda/binding.cpp"):
        os.rename("ktransformers/ktransformers_ext/cuda/binding.cpp", "ktransformers/ktransformers_ext/cuda/binding.cu")
    ops_module = CUDAExtension('KTransformersOps', [
        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
        'ktransformers/ktransformers_ext/cuda/binding.cu',
        # TODO: more files can be added in the future
        ],
        extra_compile_args={
            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
            'cucc': [
                '-O3',
                '--use_fast_math',
                '-Xcompiler', '-fPIC',
            ]
        }
    )
elif CUDA_HOME is not None:
ops_module = CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
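On the build side, setup.py selects the toolchain at configure time: when a MACA torch is detected, cmake_maca is invoked instead of cmake, binding.cpp is renamed to binding.cu (presumably so the MetaX compiler treats it as device code), and the device-side flags are passed under a 'cucc' key. A condensed sketch of the command selection only, under the assumption that cmake_maca is on PATH in MACA environments; run_cmake is an illustrative helper, not the PR's build_extension method:

import subprocess
import torch

IS_MACA_TORCH = "metax" in torch.__version__

def run_cmake(source_dir, build_dir, cmake_args):
    """Configure and build the native extension, using the MACA CMake wrapper when present."""
    cmake_cmd = "cmake_maca" if IS_MACA_TORCH else "cmake"
    subprocess.run([cmake_cmd, source_dir, *cmake_args], cwd=build_dir, check=True)
    subprocess.run([cmake_cmd, "--build", ".", "--verbose"], cwd=build_dir, check=True)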