diff --git a/pd_xpyd/mooncake.json b/pd_xpyd/mooncake.json
new file mode 100644
index 00000000000..8ea942174d8
--- /dev/null
+++ b/pd_xpyd/mooncake.json
@@ -0,0 +1,7 @@
+{
+    "local_hostname": "localhost",
+    "metadata_server": "etcd://localhost:2379",
+    "protocol": "tcp",
+    "device_name": "",
+    "master_server_address": "localhost:50001"
+}
diff --git a/pd_xpyd/run_hpu_disagg_accuracy_test.sh b/pd_xpyd/run_hpu_disagg_accuracy_test.sh
new file mode 100644
index 00000000000..ff19c969aa3
--- /dev/null
+++ b/pd_xpyd/run_hpu_disagg_accuracy_test.sh
@@ -0,0 +1,173 @@
+#!/bin/bash
+set -e
+
+GIT_ROOT=$(git rev-parse --show-toplevel 2>/dev/null)
+# Kill all background jobs on Ctrl+C (SIGINT), SIGTERM, or script exit
+trap 'pids=$(jobs -pr); [ -n "$pids" ] && kill $pids' SIGINT SIGTERM EXIT
+
+# Hosts / ports
+PREFILL_HOST=${PREFILL_HOST:-"localhost"}
+PREFILL_PORT=${PREFILL_PORT:-8100}
+DECODE_HOST=${DECODE_HOST:-"localhost"}
+DECODE_PORT=${DECODE_PORT:-8200}
+PROXY_HOST=${PROXY_HOST:-"localhost"}
+PROXY_PORT=${PROXY_PORT:-8192}
+BASELINE_HOST=${BASELINE_HOST:-"localhost"}
+BASELINE_PORT=${BASELINE_PORT:-9290}
+
+# Model to run.
+MODEL_NAME=${MODEL_NAME:-"/mnt/weka/data/pytorch/llama3/Meta-Llama-3-8B-Instruct/"}
+MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
+VLLM_GPU_MEMORY_UTILIZATION=0.8
+MODEL_LEN=2048
+max_num_batched_tokens=2048
+max_num_seqs=16
+
+OUTPUT_FILE="hpu_accuracy_test_outputs.txt"
+
+start_etcd_and_mooncake() {
+    etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://localhost:2379 > etcd.log 2>&1 &
+    mooncake_master -enable_gc true -port 50001 &> mooncake_master.log &
+    sleep 2
+}
+
+cleanup() {
+    echo "Cleaning up..."
+    sleep 2
+    pkill -f etcd || true
+    pkill -f mooncake_master || true
+    pkill -f "vllm serve" || true
+    pkill -f "disagg_proxy_demo.py" || true
+    sleep 2
+    echo "Cleanup complete."
+}
+
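+# wait_for_server polls the given host:port with curl until the endpoint
+# answers (or the 1200 s timeout expires); /v1/completions is used only as a
+# reachability probe here, no actual completion request is issued.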
+wait_for_server() {
+    local host=$1
+    local port=$2
+    timeout 1200 bash -c "
+        until curl -s ${host}:${port}/v1/completions > /dev/null; do
+            sleep 1
+        done" && return 0 || return 1
+}
+
+launch_baseline() {
+    BASELINE_BASE_CMD="
+        HABANA_VISIBLE_DEVICES="0" \
+        VLLM_USE_V1=0 \
+        VLLM_SKIP_WARMUP=True \
+        vllm serve $MODEL_NAME \
+        --port $BASELINE_PORT \
+        --seed 42 \
+        --max-model-len $MODEL_LEN \
+        --gpu-memory-utilization $VLLM_GPU_MEMORY_UTILIZATION \
+        -tp 1 \
+        --max-num-seqs $max_num_seqs \
+        --trust-remote-code \
+        --disable-log-requests \
+        --max-num-batched-tokens $max_num_batched_tokens \
+        --use-padding-aware-scheduling \
+        --dtype bfloat16 \
+        --enforce-eager
+        "
+    echo ${BASELINE_BASE_CMD}
+    bash -c "${BASELINE_BASE_CMD}" &
+}
+
+launch_pd() {
+    PREFILL_BASE_CMD="
+        HABANA_VISIBLE_DEVICES="0" \
+        MOONCAKE_CONFIG_PATH=./mooncake.json \
+        VLLM_USE_V1=0 \
+        VLLM_SKIP_WARMUP=True \
+        vllm serve $MODEL_NAME \
+        --port $PREFILL_PORT \
+        --seed 42 \
+        --max-model-len $MODEL_LEN \
+        --gpu-memory-utilization $VLLM_GPU_MEMORY_UTILIZATION \
+        -tp 1 \
+        --max-num-seqs $max_num_seqs \
+        --trust-remote-code \
+        --disable-log-requests \
+        --max-num-batched-tokens $max_num_batched_tokens \
+        --use-padding-aware-scheduling \
+        --dtype bfloat16 \
+        --enforce-eager \
+        --kv-transfer-config '{\"kv_connector\":\"MooncakeStoreConnector\",\"kv_role\":\"kv_producer\"}'
+        "
+
+    DECODE_BASE_CMD="
+        HABANA_VISIBLE_DEVICES="1" \
+        MOONCAKE_CONFIG_PATH=./mooncake.json \
+        VLLM_USE_V1=0 \
+        VLLM_SKIP_WARMUP=True \
+        vllm serve $MODEL_NAME \
+        --port $DECODE_PORT \
+        --seed 42 \
+        --max-model-len $MODEL_LEN \
+        --gpu-memory-utilization $VLLM_GPU_MEMORY_UTILIZATION \
+        -tp 1 \
+        --max-num-seqs $max_num_seqs \
+        --trust-remote-code \
+        --disable-log-requests \
+        --max-num-batched-tokens $max_num_batched_tokens \
+        --use-padding-aware-scheduling \
+        --dtype bfloat16 \
+        --enforce-eager \
+        --kv-transfer-config '{\"kv_connector\":\"MooncakeStoreConnector\",\"kv_role\":\"kv_consumer\"}'
+        "
+
+    echo ${PREFILL_BASE_CMD}
+    echo ${DECODE_BASE_CMD}
+    sleep 2
+
+    # launch prefill and decode instances in the background
+    bash -c "${PREFILL_BASE_CMD}" &
+    bash -c "${DECODE_BASE_CMD}" &
+    sleep 20
+    wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
+    sleep 1
+    wait_for_server ${DECODE_HOST} ${DECODE_PORT}
+    sleep 1
+}
+
+launch_pd_proxy(){
+    PROXY_BASE_CMD="
+        python3 ${GIT_ROOT}/examples/online_serving/disagg_examples/disagg_proxy_demo.py \
+        --model $MODEL_NAME \
+        --prefill ${PREFILL_HOST}:${PREFILL_PORT} \
+        --decode ${DECODE_HOST}:${DECODE_PORT} \
+        --port $PROXY_PORT"
+    echo ${PROXY_BASE_CMD}
+    bash -c "${PROXY_BASE_CMD}" &
+}
+
+run_tests(){
+    local service_url=$1
+    local mode=$2
+    python3 test_disagg_accuracy.py --service_url=${service_url} --model_name=$MODEL_NAME --mode=${mode} --file_name=${OUTPUT_FILE}
+}
+
+
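+# Test flow: start a single-instance baseline server and record its greedy
+# outputs, tear it down, then serve the same prompts through the
+# prefill/decode disaggregated deployment behind the proxy and require an
+# exact match against the baseline outputs.
+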
+# run non-disagg. baseline & save outputs
+launch_baseline
+sleep 10
+wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
+run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
+cleanup
+sleep 10
+
+# run disagg. & do exact-match with the outputs from baseline
+start_etcd_and_mooncake
+launch_pd
+launch_pd_proxy
+sleep 10
+run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
+echo "---- P/D success ----"
+
+rm ${OUTPUT_FILE}
+cleanup
+
+exit 0
diff --git a/pd_xpyd/test_disagg_accuracy.py b/pd_xpyd/test_disagg_accuracy.py
new file mode 100644
index 00000000000..eeca6db1766
--- /dev/null
+++ b/pd_xpyd/test_disagg_accuracy.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import json
+import time
+
+import openai
+import requests
+
+MAX_OUTPUT_LEN = 30
+
+SAMPLE_PROMPTS = (
+    "Red Hat is the best company in the world to work for because it works on "
+    "open source software, which means that all the contributions are "
+    "delivered to the community. As a result, when working on projects like "
+    "vLLM we are able to meet many amazing people from various organizations "
+    "like AMD, Google, NVIDIA, ",
+    "We hold these truths to be self-evident, that all men are created equal, "
+    "that they are endowed by their Creator with certain unalienable Rights, "
+    "that among these are Life, Liberty and the pursuit of Happiness.--That "
+    "to secure these rights, Governments are instituted among Men, deriving "
+    "their just powers from the consent of the governed, ",
+)
+
+
+def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
+    """
+    Checks if the vLLM server is ready by sending a GET request to the
+    given URL (e.g. its /health endpoint).
+
+    Args:
+        url (str): The URL to probe on the vLLM server.
+        timeout (int): Timeout in seconds for the request.
+        retries (int): Number of retries if the server is not ready.
+
+    Returns:
+        bool: True if the server is ready, False otherwise.
+    """
+    for attempt in range(retries):
+        try:
+            response = requests.get(url, timeout=timeout)
+            if response.status_code == 200:
+                return True
+            else:
+                print(f"Attempt {attempt + 1}: Server returned status code "
+                      f"{response.status_code}")
+        except requests.exceptions.RequestException as e:
+            print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
+        time.sleep(1)  # Wait before retrying
+    return False
+
+
+def run_simple_prompt(base_url: str, model_name: str,
+                      input_prompt: str) -> str:
+    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
+    completion = client.completions.create(model=model_name,
+                                           prompt=input_prompt,
+                                           max_tokens=MAX_OUTPUT_LEN,
+                                           temperature=0.0,
+                                           seed=42)
+
+    return completion.choices[0].text
+
+
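+# Note: completions are requested with temperature=0.0 and a fixed seed, so
+# the baseline and disaggregated runs are expected to produce identical text
+# and can be compared with an exact string match.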
+ """ + parser = argparse.ArgumentParser(description="vLLM client script") + + parser.add_argument( + "--service_url", # Name of the first argument + type=str, + required=True, + help="The vLLM service URL.") + + parser.add_argument( + "--model_name", # Name of the first argument + type=str, + required=True, + help="model_name", + ) + + parser.add_argument( + "--mode", # Name of the second argument + type=str, + default="baseline", + help="mode: baseline==non-disagg, or disagg", + ) + + parser.add_argument( + "--file_name", # Name of the second argument + type=str, + default=".vllm_output.txt", + help="the file that saves the output tokens ", + ) + + args = parser.parse_args() + + for arg in vars(args): + print(f"{arg}: {getattr(args, arg)}") + + service_url = f"{args.service_url}/v1" + + output_strs = dict() + for prompt in SAMPLE_PROMPTS: + output_str = run_simple_prompt(base_url=service_url, + model_name=args.model_name, + input_prompt=prompt) + print(f"Prompt: {prompt}, output: {output_str}") + output_strs[prompt] = output_str + + if args.mode == "baseline": + # baseline: save outputs + try: + with open(args.file_name, 'w') as json_file: + json.dump(output_strs, json_file, indent=4) + except OSError as e: + print(f"Error writing to file: {e}") + raise + else: + # disagg. verify outputs + baseline_outputs = None + try: + with open(args.file_name) as json_file: + baseline_outputs = json.load(json_file) + except OSError as e: + print(f"Error writing to file: {e}") + raise + assert isinstance(baseline_outputs, dict) + assert len(baseline_outputs) == len(output_strs) + for prompt, output in baseline_outputs.items(): + assert prompt in output_strs, f"{prompt} not included" + assert output == output_strs[prompt], ( + f"baseline_output: {output} != PD output: {output_strs[prompt]}" + ) + + +if __name__ == "__main__": + main()