Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
- 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.
Describe the bug
I am not sure how to run the balance_serve backend with a full 128k context. If I try something like this:
ktransformers \
--gguf_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-R1-0528-UD-Q2_K_XL-00001-of-00006.gguf \
--model_path /opt/unsloth/DeepSeek-R1-0528-GGUF/ \
--model_name unsloth/DeepSeek-R1-0528-Q2-GGUF \
--cpu_infer $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
--optimize_config_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-V3-Chat-serve.yaml \
--backend_type balance_serve \
--cache_8bit True \
--cache_lens $((128*1024)) \
--length $((128*1024)) \
--max_new_tokens $((128*1024)) \
--chunk_size 256 \
--max_batch_size 1 \
--temperature 0.6 \
--top_p 0.95 \
--host 0.0.0.0 \
--port 8080 \
--fast_safetensors True \
--log_level DEBUG \
--use_cuda_graph
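As a rough sanity check that a 128k MLA cache should fit in 24 GiB at all, here is a back-of-envelope estimate. The head dimensions are taken from the flashinfer JIT op name in the log below (head_dim_ckv_512, head_dim_kpe_64); the 61-layer count and the one-byte-per-element reading of --cache_8bit are my own assumptions, not numbers confirmed by ktransformers:

# Back-of-envelope MLA KV-cache size at --cache_lens $((128*1024)) (hypothetical numbers).
head_dim_ckv = 512        # compressed KV dim (kv_lora_rank), from the JIT op name below
head_dim_kpe = 64         # rope dim (qk_rope_head_dim), from the JIT op name below
num_layers = 61           # assumed layer count for DeepSeek-R1/V3
cache_lens = 128 * 1024   # tokens of cache requested
bytes_per_elem = 1        # assuming --cache_8bit True stores ~1 byte per element

total_bytes = cache_lens * num_layers * (head_dim_ckv + head_dim_kpe) * bytes_per_elem
print(f"~{total_bytes / 2**30:.1f} GiB")  # prints ~4.3 GiB, well under 24 GiB

So memory does not look like the limiting factor here.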
Instead, I get the following error during CUDA graph capture:
Rebuilding kvcache
kv_cache loaded successfully.
capturing cuda graph 1 1
2025-07-03 13:06:22,408 - INFO - flashinfer.jit: Loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False
2025-07-03 13:06:46,487 - INFO - flashinfer.jit: Finished loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False
2025-07-03 13:06:46,536 - INFO - flashinfer.jit: Loading JIT ops: norm
2025-07-03 13:07:02,601 - INFO - flashinfer.jit: Finished loading JIT ops: norm
cuda_graph: 1/6, warmup finished.
capturing cuda graph 2 2
cuda_graph: 2/6, warmup finished.
capturing cuda graph 3 3
Process SpawnProcess-1:
Traceback (most recent call last):
File "/usr/lib/python3.13/multiprocessing/process.py", line 313, in _bootstrap
self.run()
~~~~~~~~^^
File "/usr/lib/python3.13/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 279, in run_engine
engine.model_runner.warmup()
~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 133, in warmup
self.model_attn_plan(self.input[i], i)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 92, in model_attn_plan
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/models/custom_modeling_deepseek_v3.py", line 146, in flash_infer_attn_plan
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/flashinfer/mla.py", line 246, in plan
self._qo_indptr_buf[:len(qo_indptr)].copy_(qo_indptr, non_blocking=True)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0
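For reference, the failing line in flashinfer's mla.py copies the per-minibatch qo_indptr into a buffer that was preallocated when the MLA wrapper was set up. Below is a minimal sketch of the same mismatch with plain torch tensors, using the sizes 3 and 4 from the error message (how the real buffer ends up with length 3 is my guess, presumably tied to how many requests the CUDA graph was planned for):

import torch

# Preallocated indptr buffer of length 3, as in the error message (hypothetical origin).
qo_indptr_buf = torch.zeros(3, dtype=torch.int32)

# qo_indptr for the third CUDA graph capture has length 4 (one more entry than requests).
qo_indptr = torch.tensor([0, 1, 2, 3], dtype=torch.int32)

# Mirrors flashinfer's `self._qo_indptr_buf[:len(qo_indptr)].copy_(qo_indptr)`:
# slicing a length-3 tensor to [:4] still yields length 3, so copy_ raises
# "The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0".
qo_indptr_buf[:len(qo_indptr)].copy_(qo_indptr)

If that reading is right, the preallocated buffer is sized for fewer requests than the third warmup batch actually uses.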
I do not hit a similar issue with the ktransformers backend, but that backend does not support disk prefix caching. What should I do?
[UPDATE]: a temporary, vibe-coded workaround has been found; see #1417 (comment).
Reproduction
ktransformers \
--gguf_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-R1-0528-UD-Q2_K_XL-00001-of-00006.gguf \
--model_path /opt/unsloth/DeepSeek-R1-0528-GGUF/ \
--model_name unsloth/DeepSeek-R1-0528-Q2-GGUF \
--cpu_infer $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
--optimize_config_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-V3-Chat-serve.yaml \
--backend_type balance_serve \
--cache_8bit True \
--cache_lens $((128*1024)) \
--length $((128*1024)) \
--max_new_tokens $((128*1024)) \
--chunk_size 256 \
--max_batch_size 1 \
--temperature 0.6 \
--top_p 0.95 \
--host 0.0.0.0 \
--port 8080 \
--fast_safetensors True \
--log_level DEBUG \
--use_cuda_graph
Environment
uname -a
Linux xxx 6.12.32-amd64 #1 SMP PREEMPT_DYNAMIC Debian 6.12.32-1 (2025-06-07) x86_64 GNU/Linux
nvidia-smi
Thu Jul 3 12:32:18 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.02 Driver Version: 575.51.02 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:41:00.0 Off | N/A |
| 30% 35C P8 31W / 350W | 160MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
uv pip list
Using Python 3.13.3 environment at: /opt/ktransformers/.ktransformers
Package Version
------------------------ ------------------------
accelerate 1.8.1
annotated-types 0.7.0
anyio 4.9.0
blessed 1.21.0
build 1.2.2.post1
certifi 2025.6.15
charset-normalizer 3.4.2
click 8.2.1
colorlog 6.9.0
cpufeature 0.2.1
distro 1.9.0
fastapi 0.115.14
filelock 3.18.0
fire 0.7.0
flashinfer-python 0.2.3
fsspec 2025.5.1
greenlet 3.2.3
h11 0.16.0
hf-xet 1.1.5
httpcore 1.0.9
httpx 0.28.1
huggingface-hub 0.33.2
idna 3.10
jinja2 3.1.6
jiter 0.10.0
jsonpatch 1.33
jsonpointer 3.0.0
ktransformers 0.3.2+cu129torch29avx2
langchain 0.3.26
langchain-core 0.3.67
langchain-text-splitters 0.3.8
langsmith 0.4.4
markupsafe 2.1.5
mpmath 1.3.0
networkx 3.5
ninja 1.11.1.4
numpy 2.3.1
nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-cufile-cu12 1.13.1.3
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.3
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.2.5
nvidia-nvtx-cu12 12.8.90
openai 1.92.2
orjson 3.10.18
packaging 24.2
pillow 11.2.1
protobuf 6.31.1
psutil 7.0.0
pydantic 2.11.7
pydantic-core 2.33.2
pyproject-hooks 1.2.0
pytorch-triton 3.3.1+gitc8757738
pyyaml 6.0.2
pyzmq 27.0.0
regex 2024.11.6
requests 2.32.4
requests-toolbelt 1.0.0
safetensors 0.5.3
sentencepiece 0.2.0
setuptools 78.1.0
sniffio 1.3.1
sqlalchemy 2.0.41
starlette 0.46.2
sympy 1.14.0
tenacity 9.1.2
termcolor 3.1.0
tokenizers 0.21.2
torch 2.9.0.dev20250702+cu128
torchaudio 2.8.0.dev20250702+cu128
torchvision 0.24.0.dev20250702+cu128
tqdm 4.67.1
transformers 4.51.3
triton 3.3.1
typing-extensions 4.14.0
typing-inspection 0.4.1
urllib3 2.5.0
uvicorn 0.35.0
wcwidth 0.2.13
wheel 0.45.1
zmq 0.0.0
zstandard 0.23.0