[P/D] Support CPU Transfer in NixlConnector #18293

Open · wants to merge 35 commits into base: main
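This PR adds a CPU (host-buffer) transfer path to NixlConnector so prefill/decode disaggregation can run on TPU: KV blocks are staged through host memory instead of being transferred device-to-device (the commits below touch the h2d/d2h copy attributes and a use_host_buffer flag). The path is selected through the connector config's kv_buffer_device field; a minimal sketch of the serve-time flag, as the test scripts in this diff use it:

vllm serve meta-llama/Llama-3.2-3B-Instruct \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}'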
Commits (35) · showing changes from 28 commits
b3e7146
tpu: Support CPU Transfer in NixlConnector
juncgu Jun 12, 2025
6eb93e2
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 18, 2025
d8041e6
fix device_index
juncgu Jun 19, 2025
98562fb
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 19, 2025
16b3884
Merge branch 'main' into tpu_nixl_merge
richardsliu Jun 23, 2025
62b4460
fix error
richardsliu Jun 23, 2025
382588e
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 23, 2025
b1ec962
fix comments
juncgu Jun 24, 2025
5a1e892
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 24, 2025
0995bbd
fix recving_meta at decode
juncgu Jun 24, 2025
71cf953
tweaks
juncgu Jun 25, 2025
0ab79dc
tweak
juncgu Jun 25, 2025
4961a32
fix comments
juncgu Jun 26, 2025
caf314f
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 26, 2025
e6c61de
fix for multi-forwards in a single input batch
juncgu Jun 26, 2025
0fa99d4
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 26, 2025
d402f52
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 26, 2025
73d4ff3
fix comments
juncgu Jun 30, 2025
6fa4861
Merge branch 'main' into tpu_nixl_merge
juncgu Jun 30, 2025
33bbfdd
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 1, 2025
787e9ad
revise h2d/d2h attributes
juncgu Jul 2, 2025
6fa07b7
rename ReqMeta attributes for use_host_buffer
juncgu Jul 3, 2025
9877752
revise _NIXL_SUPPORTED_XPUS
juncgu Jul 7, 2025
e216151
SPDX license
juncgu Jul 7, 2025
0637bdf
use tp_rank for device_id in nixl data block
juncgu Jul 7, 2025
0ac9a6f
use tp_rank for device_id in nixl data block
juncgu Jul 7, 2025
e55ce7d
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 8, 2025
2ba3daa
rm redundant code
juncgu Jul 9, 2025
4e72df9
fix assertation
juncgu Jul 9, 2025
b875e28
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 9, 2025
9f8280e
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 10, 2025
aca4246
update tpu worker & model_runner
juncgu Jul 11, 2025
1d9bc01
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 11, 2025
506f3dd
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 14, 2025
bd19d33
Merge branch 'main' into tpu_nixl_merge
juncgu Jul 16, 2025
1 change: 1 addition & 0 deletions requirements/tpu.txt
@@ -10,6 +10,7 @@ jinja2>=3.1.6
ray[default]
ray[data]
setuptools==78.1.0
nixl==0.3.0

# Install torch_xla
--pre
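For completeness, a sketch of pulling in and verifying the new dependency on a TPU host (assumes a pip-based environment):

pip install -r requirements/tpu.txt   # includes the new nixl==0.3.0 pin
pip show nixl                         # should report Version: 0.3.0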
162 changes: 162 additions & 0 deletions tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh
@@ -0,0 +1,162 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Kill background jobs on SIGINT (Ctrl+C), SIGTERM, or script exit
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT


# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}

# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}

launch_baseline() {
BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${BASELINE_HOST} \
--port ${BASELINE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--enforce-eager"
echo ${BASELINE_BASE_CMD}
ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}

launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2

# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}

launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}

run_tests(){
local service_url=$1
local mode=$2
python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}


# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10


# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"

rm ${OUTPUT_FILE}
cleanup

exit 0
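Every host, port, and model setting above is an environment default, so the script can be retargeted without edits. A hypothetical two-host invocation (host names are placeholders, not part of this PR):

PREFILL_HOST=tpu-vm-a DECODE_HOST=tpu-vm-b \
PROXY_HOST=tpu-vm-a BASELINE_HOST=tpu-vm-a \
bash tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh

Note that each host must be reachable over non-interactive SSH with the conda env (CONDA_ENV_NAME, default "nixl") already prepared, since launch_baseline and launch_pd run their commands through ssh -tt.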
128 changes: 128 additions & 0 deletions tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh
@@ -0,0 +1,128 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Kill background jobs on SIGINT (Ctrl+C), SIGTERM, or script exit
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}

# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}


launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2

# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}

launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}


# launch disagg. deployment & run the edge-case tests via pytest
launch_pd
launch_pd_proxy
sleep 10

PREFILL_HOST=${PREFILL_HOST} \
PREFILL_PORT=${PREFILL_PORT} \
DECODE_HOST=${DECODE_HOST} \
DECODE_PORT=${DECODE_PORT} \
PROXY_HOST=${PROXY_HOST} \
PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
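The env vars forwarded to pytest above are the suite's only wiring, so the test step can plausibly be re-run on its own against an already-launched P/D deployment, e.g. with the script's defaults:

PREFILL_HOST=localhost PREFILL_PORT=8100 \
DECODE_HOST=localhost DECODE_PORT=8200 \
PROXY_HOST=localhost PROXY_PORT=8192 \
python -m pytest -s -v tests/v1/kv_connector/nixl_integration/test_edge_cases.py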