
[Bug] deepseek v3/r1 with full context with balance_serve backend #1417

@magikRUKKOLA

Description


Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/kvcache-ai/ktransformers/discussions. Otherwise, it will be closed.
  • 5. To help the community, I will use Chinese/English or attach a Chinese/English translation if using another language. Non-Chinese/English content without translation may be closed.

Describe the bug

I am not sure how to run the balance_serve backend with a full 128k context. If I try something like this:

ktransformers \
  --gguf_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-R1-0528-UD-Q2_K_XL-00001-of-00006.gguf \
  --model_path /opt/unsloth/DeepSeek-R1-0528-GGUF/ \
  --model_name unsloth/DeepSeek-R1-0528-Q2-GGUF \
  --cpu_infer $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
  --optimize_config_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-V3-Chat-serve.yaml \
  --backend_type balance_serve \
  --cache_8bit True \
  --cache_lens $((128*1024)) \
  --length $((128*1024)) \
  --max_new_tokens $((128*1024)) \
  --chunk_size 256 \
  --max_batch_size 1 \
  --temperature 0.6 \
  --top_p 0.95 \
  --host 0.0.0.0 \
  --port 8080 \
  --fast_safetensors True \
  --log_level DEBUG \
  --use_cuda_graph

I get the following error during CUDA graph capture.

Rebuilding kvcache
kv_cache loaded successfully.
capturing cuda graph 1 1
2025-07-03 13:06:22,408 - INFO - flashinfer.jit: Loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False
2025-07-03 13:06:46,487 - INFO - flashinfer.jit: Finished loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False
2025-07-03 13:06:46,536 - INFO - flashinfer.jit: Loading JIT ops: norm
2025-07-03 13:07:02,601 - INFO - flashinfer.jit: Finished loading JIT ops: norm
cuda_graph: 1/6, warmup finished.
capturing cuda graph 2 2
cuda_graph: 2/6, warmup finished.
capturing cuda graph 3 3
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.13/multiprocessing/process.py", line 313, in _bootstrap
    self.run()
    ~~~~~~~~^^
  File "/usr/lib/python3.13/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 279, in run_engine
    engine.model_runner.warmup()
    ~~~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 133, in warmup
    self.model_attn_plan(self.input[i], i)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 92, in model_attn_plan
    self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                     num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                     head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                     sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/ktransformers/models/custom_modeling_deepseek_v3.py", line 146, in flash_infer_attn_plan
    self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                      minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ktransformers/.ktransformers/lib/python3.13/site-packages/flashinfer/mla.py", line 246, in plan
    self._qo_indptr_buf[:len(qo_indptr)].copy_(qo_indptr, non_blocking=True)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0
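
For context, here is a minimal sketch (plain PyTorch, with purely illustrative sizes; the real buffer is allocated inside flashinfer's MLA wrapper, so the exact numbers are an assumption) of why the copy_ fails: the preallocated qo_indptr buffer is shorter than the indptr built while warming up the third CUDA graph, so slicing the buffer silently truncates it and copy_ sees mismatched lengths.

import torch

# Illustrative sizes only (assumed, not taken from flashinfer internals):
# suppose the wrapper preallocated an indptr buffer of length 3, while warmup
# for the third CUDA graph builds a qo_indptr of length batch_size + 1 = 4.
qo_indptr_buf = torch.zeros(3, dtype=torch.int32)
qo_indptr = torch.arange(4, dtype=torch.int32)

# Slicing a length-3 buffer with [:4] still yields only 3 elements, so the
# copy below raises:
#   RuntimeError: The size of tensor a (3) must match the size of tensor b (4)
qo_indptr_buf[: len(qo_indptr)].copy_(qo_indptr, non_blocking=True)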

I do not hit a similar issue with the ktransformers backend, but that backend does not support disk prefix caching. So what should I do?

[UPDATE]: an ephemeral, vibe-coded workaround has been found for now; see #1417 (comment).

Reproduction

ktransformers \
  --gguf_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-R1-0528-UD-Q2_K_XL-00001-of-00006.gguf \
  --model_path /opt/unsloth/DeepSeek-R1-0528-GGUF/ \
  --model_name unsloth/DeepSeek-R1-0528-Q2-GGUF \
  --cpu_infer $(grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}' | xargs -I{} echo "{}-0" | bc) \
  --optimize_config_path /opt/unsloth/DeepSeek-R1-0528-GGUF/UD-Q2_K_XL/DeepSeek-V3-Chat-serve.yaml \
  --backend_type balance_serve \
  --cache_8bit True \
  --cache_lens $((128*1024)) \
  --length $((128*1024)) \
  --max_new_tokens $((128*1024)) \
  --chunk_size 256 \
  --max_batch_size 1 \
  --temperature 0.6 \
  --top_p 0.95 \
  --host 0.0.0.0 \
  --port 8080 \
  --fast_safetensors True \
  --log_level DEBUG \
  --use_cuda_graph

Environment

uname -a
Linux xxx 6.12.32-amd64 #1 SMP PREEMPT_DYNAMIC Debian 6.12.32-1 (2025-06-07) x86_64 GNU/Linux
nvidia-smi
Thu Jul  3 12:32:18 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.02              Driver Version: 575.51.02      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:41:00.0 Off |                  N/A |
| 30%   35C    P8             31W /  350W |     160MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
uv pip list
Using Python 3.13.3 environment at: /opt/ktransformers/.ktransformers
Package                  Version
------------------------ ------------------------
accelerate               1.8.1
annotated-types          0.7.0
anyio                    4.9.0
blessed                  1.21.0
build                    1.2.2.post1
certifi                  2025.6.15
charset-normalizer       3.4.2
click                    8.2.1
colorlog                 6.9.0
cpufeature               0.2.1
distro                   1.9.0
fastapi                  0.115.14
filelock                 3.18.0
fire                     0.7.0
flashinfer-python        0.2.3
fsspec                   2025.5.1
greenlet                 3.2.3
h11                      0.16.0
hf-xet                   1.1.5
httpcore                 1.0.9
httpx                    0.28.1
huggingface-hub          0.33.2
idna                     3.10
jinja2                   3.1.6
jiter                    0.10.0
jsonpatch                1.33
jsonpointer              3.0.0
ktransformers            0.3.2+cu129torch29avx2
langchain                0.3.26
langchain-core           0.3.67
langchain-text-splitters 0.3.8
langsmith                0.4.4
markupsafe               2.1.5
mpmath                   1.3.0
networkx                 3.5
ninja                    1.11.1.4
numpy                    2.3.1
nvidia-cublas-cu12       12.8.4.1
nvidia-cuda-cupti-cu12   12.8.90
nvidia-cuda-nvrtc-cu12   12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12        9.10.2.21
nvidia-cufft-cu12        11.3.3.83
nvidia-cufile-cu12       1.13.1.3
nvidia-curand-cu12       10.3.9.90
nvidia-cusolver-cu12     11.7.3.90
nvidia-cusparse-cu12     12.5.8.93
nvidia-cusparselt-cu12   0.7.1
nvidia-nccl-cu12         2.27.3
nvidia-nvjitlink-cu12    12.8.93
nvidia-nvshmem-cu12      3.2.5
nvidia-nvtx-cu12         12.8.90
openai                   1.92.2
orjson                   3.10.18
packaging                24.2
pillow                   11.2.1
protobuf                 6.31.1
psutil                   7.0.0
pydantic                 2.11.7
pydantic-core            2.33.2
pyproject-hooks          1.2.0
pytorch-triton           3.3.1+gitc8757738
pyyaml                   6.0.2
pyzmq                    27.0.0
regex                    2024.11.6
requests                 2.32.4
requests-toolbelt        1.0.0
safetensors              0.5.3
sentencepiece            0.2.0
setuptools               78.1.0
sniffio                  1.3.1
sqlalchemy               2.0.41
starlette                0.46.2
sympy                    1.14.0
tenacity                 9.1.2
termcolor                3.1.0
tokenizers               0.21.2
torch                    2.9.0.dev20250702+cu128
torchaudio               2.8.0.dev20250702+cu128
torchvision              0.24.0.dev20250702+cu128
tqdm                     4.67.1
transformers             4.51.3
triton                   3.3.1
typing-extensions        4.14.0
typing-inspection        0.4.1
urllib3                  2.5.0
uvicorn                  0.35.0
wcwidth                  0.2.13
wheel                    0.45.1
zmq                      0.0.0
zstandard                0.23.0
