Add TP2 test scripts for 1P/1D #1409

Merged (2 commits, Jun 18, 2025)
@@ -0,0 +1,139 @@
#!/bin/bash

echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."


PIDS=()

# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"

check_hf_token() {
    if [ -z "$HF_TOKEN" ]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        exit 1
    fi
    if [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN does not look like a valid Hugging Face token (it should start with 'hf_')."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}

check_num_gpus() {
    # Check that at least 4 Gaudi devices are available via hl-smi.
    num_gpus=$(hl-smi --query-gpu=name --format=csv,noheader | wc -l)
    if [ "$num_gpus" -lt 4 ]; then
        echo "You need at least 4 GPUs to run disaggregated prefill with TP2."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}

ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    python -c "import $1" > /dev/null 2>&1
    if [ $? -ne 0 ]; then
        if [ "$1" == "nixl" ]; then
            echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
        else
            echo "$1 is not installed. Please install it via pip install $1."
        fi
        exit 1
    else
        echo "$1 is installed."
    fi
}

cleanup() {
    echo "Stopping everything..."
    trap - INT TERM    # prevent re-entrancy
    kill -- -$$        # negative PID == "this whole process group"
    wait               # reap children so we don't leave zombies
    exit 0
}

wait_for_server() {
    local port=$1
    local timeout_seconds=1200
    local start_time=$(date +%s)

    echo "Waiting for server on port $port..."

    while true; do
        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
            return 0
        fi

        local now=$(date +%s)
        if (( now - start_time >= timeout_seconds )); then
            echo "Timeout waiting for server"
            return 1
        fi

        sleep 1
    done
}


main() {
    #check_hf_token
    check_num_gpus
    ensure_python_library_installed lmcache
    #ensure_python_library_installed nixl
    ensure_python_library_installed pandas
    ensure_python_library_installed datasets
    ensure_python_library_installed vllm

    trap cleanup INT
    trap cleanup USR1
    trap cleanup TERM

    echo "Launching prefiller, decoder and proxy..."
    echo "Please check prefiller.log, decoder.log and proxy.log for logs."

    echo "Starting LMCache server..."
    python -m lmcache.v1.server localhost 8100 2>&1 &

    echo "Starting prefiller..."
    bash disagg_vllm_launcher_gaudi_lm_tp2.sh prefiller \
        > >(tee prefiller.log) 2>&1 &
    prefiller_pid=$!
    PIDS+=($prefiller_pid)

    echo "Starting decoder..."
    bash disagg_vllm_launcher_gaudi_lm_tp2.sh decoder \
        > >(tee decoder.log) 2>&1 &
    decoder_pid=$!
    PIDS+=($decoder_pid)

    python3 disagg_proxy_server.py \
        --host localhost \
        --port 1000 \
        --prefiller-host localhost \
        --prefiller-port 1100 \
        --decoder-host localhost \
        --decoder-port 1200 \
        > >(tee proxy.log) 2>&1 &
    proxy_pid=$!
    PIDS+=($proxy_pid)

    wait_for_server 1100
    wait_for_server 1200
    wait_for_server 1000

    echo "All servers are up. Starting benchmark..."

    # Begin benchmark
    cd ../../../benchmarks/
    python benchmark_serving.py --port 1000 --seed $(date +%s) \
        --model /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct/ \
        --dataset-name random --random-input-len 8000 --random-output-len 200 \
        --num-prompts 100 --burstiness 100 --request-rate 3.6 | tee benchmark.log

    echo "Benchmarking done. Cleaning up..."

    cleanup

}

main
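
A minimal invocation sketch for the end-to-end script above. The script's own filename is not visible in this diff, so the name used below is hypothetical; it assumes the script lives next to disagg_vllm_launcher_gaudi_lm_tp2.sh and disagg_proxy_server.py, which the launch commands above require.

# Hypothetical filename; substitute the actual name of the script above.
# The script cds to its own directory, so it can be launched from anywhere.
bash disagg_example_gaudi_lm_tp2.sh
# Logs go to prefiller.log, decoder.log, proxy.log and benchmark.log next to
# the script; Ctrl-C triggers the cleanup trap and tears down all children.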
@@ -0,0 +1,61 @@
#!/bin/bash

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <prefiller | decoder> [model]"
    exit 1
fi

if [[ $# -eq 1 ]]; then
    MODEL="/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct/"
    echo "Using default model: $MODEL"
else
    echo "Using model: $2"
    MODEL=$2
fi


if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 1100
    prefill_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml

    #UCX_TLS=tcp \
    LMCACHE_CONFIG_FILE=$prefill_config_file \
    VLLM_ENABLE_V1_MULTIPROCESSING=1 \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    LMCACHE_REMOTE_SERDE=naive \
    LMCACHE_CHUNK_SIZE=256 \
    vllm serve $MODEL \
        --port 1100 \
        --gpu_memory_utilization 0.5 \
        --disable-log-requests \
        --tensor_parallel_size 2 \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'


elif [[ $1 == "decoder" ]]; then
    # Decoder listens on port 1200
    decode_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml

    #UCX_TLS=tcp \
    LMCACHE_CONFIG_FILE=$decode_config_file \
    VLLM_ENABLE_V1_MULTIPROCESSING=1 \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    LMCACHE_REMOTE_SERDE=naive \
    LMCACHE_CHUNK_SIZE=256 \
    vllm serve $MODEL \
        --port 1200 \
        --gpu_memory_utilization 0.5 \
        --disable-log-requests \
        --tensor_parallel_size 2 \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'


else
    echo "Invalid role: $1"
    echo "Expected 'prefiller' or 'decoder'."
    exit 1
fi
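
For reference, a minimal sketch of driving the launcher directly, one role per process, using only names that appear in the scripts above (the end-to-end test script does this automatically):

# Requires at least 4 Gaudi devices: 2 per role with --tensor_parallel_size 2.
bash disagg_vllm_launcher_gaudi_lm_tp2.sh prefiller > prefiller.log 2>&1 &
bash disagg_vllm_launcher_gaudi_lm_tp2.sh decoder > decoder.log 2>&1 &
# An explicit model may be passed as the second argument, e.g.:
# bash disagg_vllm_launcher_gaudi_lm_tp2.sh decoder meta-llama/Llama-3.1-8B-Instruct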
132 changes: 132 additions & 0 deletions examples/lmcache/kv_cache_sharing_lmcache_v1_lm_tp2.py
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).

Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process

from lmcache.integration.vllm.utils import ENGINE_NAME
from lmcache.v1.cache_engine import LMCacheEngineBuilder

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# The port to start LMCache server
port = 8100
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"

MODEL = (
    "/software/data/pytorch/huggingface/hub/"
    "models--meta-llama--Llama-3.2-1B-Instruct/snapshots/"
    "9213176726f574b556790deb65791e0c5aa438b6/"
)

#prompts = [
#    "Hello, how are you?" * 1000,
#]
prompts = [
    "San Francisco is a",
]


def run_store(store_done, prompts):
    # The KV cache store process uses 2 devices (tensor_parallel_size=2).
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_producer"}')
    # Set GPU memory utilization to 0.8. Reduce the value if your
    # device has less memory.
    llm = LLM(model=MODEL,
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2,
              enforce_eager=False)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Producer Generated text: {generated_text!r}")
    print("KV cache store is finished.")
    store_done.set()

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_retrieve(store_done, prompts, timeout=1):
    # The KV cache retrieve process also uses 2 devices (tensor_parallel_size=2).
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=20)
    # sampling_params = SamplingParams(temperature=0, max_tokens=100)
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_consumer"}')
    # Set GPU memory utilization to 0.8. Reduce the value if your
    # device has less memory.
    llm = LLM(model=MODEL,
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2,
              enforce_eager=False)

    print("Waiting for KV cache store to finish...")
    store_done.wait()
    time.sleep(timeout)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Consumer Generated text: {generated_text!r}")

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_lmcache_server(port):
    server_proc = subprocess.Popen(
        ["python", "-m", "lmcache.v1.server", "localhost",
         str(port)])
    return server_proc


def main():
    store_done = Event()
    store_process = Process(target=run_store, args=(store_done, prompts))
    retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))

    lmcache_server_process = run_lmcache_server(port)

    # Start KV cache store process
    print("Starting KV cache store process...")
    store_process.start()

    # Start KV cache retrieve process
    print("Starting KV cache retrieve process...")
    retrieve_process.start()

    store_process.join()
    retrieve_process.join()

    # Clean up the processes
    retrieve_process.terminate()
    lmcache_server_process.terminate()
    lmcache_server_process.wait()


if __name__ == "__main__":
    main()
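
A minimal sketch of running the KV cache sharing example. It assumes LMCache and vLLM are installed and that the hard-coded MODEL path exists locally; with tensor_parallel_size=2 per instance, the two vLLM processes need 4 devices in total.

# The producer stores KV cache to the LMCache server started on port 8100;
# the consumer waits for the store to finish and then reuses the cached KV.
python examples/lmcache/kv_cache_sharing_lmcache_v1_lm_tp2.py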