diff --git a/examples/online_serving/dllm_tool/dllm/__init__.py b/examples/online_serving/dllm_tool/dllm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/online_serving/dllm_tool/dllm/balancer.py b/examples/online_serving/dllm_tool/dllm/balancer.py new file mode 100644 index 00000000000..bb726905615 --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/balancer.py @@ -0,0 +1,244 @@ +import asyncio +import logging +import time +from typing import Dict, List, Optional + +import aiohttp +import ray +from prometheus_client.parser import text_string_to_metric_families +from ray import actor + +from dllm import constants +from dllm.constants import CONTROLLER_ACTOR_NAME, DLLM_NAMESPACE +from dllm.entities import (DispatchResult, MetricsInfo, Role, SchedulerPolicy, + VllmInstanceInfo, VllmInstanceStatus) + +logger = logging.getLogger(__name__) + + +class Balancer: + + def __init__( + self, + policy: SchedulerPolicy = SchedulerPolicy.ROUND_ROBIN, + ): + self.policy = policy + self.role_2_instances: Dict[Role, List[VllmInstanceInfo]] = { + } # prefill/decode/mixed => VllmInstanceInfo + self.instance_infos: Dict[str, VllmInstanceInfo] = { + } # id -> VllmInstanceInfo + self.instance_metrics: Dict[str, MetricsInfo] = {} # id -> MetricsInfo + self._round_robin_index_p = 0 + self._round_robin_index_d = 0 + self._round_robin_index_m = 0 + self.last_heartbeat: Dict[str, float] = {} + self._controller_handle: Optional[actor.ActorHandle] = None + self.all_instances_ready = False + # start update metrics loop + loop = asyncio.get_event_loop() + loop.create_task(self.update_vllm_instance_metrics()) + + async def update_vllm_instance_metrics(self): + while True: + try: + async with aiohttp.ClientSession() as session: + await asyncio.gather( + *[ + self._query_instance_metrics( + session, instance_info) + for instance_info in self.instance_infos.values() + if instance_info.uri is not None + ], + return_exceptions=True, + ) + await asyncio.sleep(constants.METRICS_UPDATE_CYCLE) + except Exception as e: + logger.error("create request session error: %s", e) + + def dispatch_request(self) -> DispatchResult: + if self.policy == SchedulerPolicy.ROUND_ROBIN: + return self._round_robin_pair() + else: + raise ValueError(f"Unsupported policy: {self.policy}") + + def get_all_instance(self) -> Dict[str, VllmInstanceInfo]: + '''Return all vllm instance.''' + return self.instance_infos + + async def _query_instance_metrics(self, session, instance_info): + ins_uri = instance_info.uri + ins_id = instance_info.id + async with session.post(f"{ins_uri}/metrics", timeout=3) as resp: + resp_code = resp.status + if resp_code != constants.HTTP_OK: + logger.error( + f"get metrics failed, uri:{ins_uri}, code:{resp_code}") + return + resp_body = await resp.text() + # {metric_name: metric_value} + metrics_dict = { + metric_family.name: metric_family.samples[0].value + for metric_family in text_string_to_metric_families(resp_body) + if metric_family.name in MetricsInfo.METRIC_NAME_MAPPING. 
+ values() and metric_family.samples + } + if not metrics_dict: + return + if ins_id not in self.instance_metrics: + self.instance_metrics[ins_id] = MetricsInfo() + metric_info = self.instance_metrics[ins_id] + for param_name, metric_name in MetricsInfo.METRIC_NAME_MAPPING.items( + ): + if metric_name not in metrics_dict: + continue + # data type conversion + target_type = metric_info.__annotations__[param_name] + setattr(metric_info, param_name, + target_type(metrics_dict[metric_name])) + logger.debug("instance metrics info: %s", self.instance_metrics) + + def _round_robin_pair(self) -> DispatchResult: + # current policy: if has mixed, use mixed + is_pd_disagged = Role.MIXED not in self.role_2_instances or len( + self.role_2_instances[Role.MIXED]) == 0 + if not is_pd_disagged: + mixed_uri = self._round_robin_selection(Role.MIXED) + return DispatchResult(prefill_uri=None, decode_uri=mixed_uri) + + prefill_uri = self._round_robin_selection(Role.PREFILL) + decode_uri = self._round_robin_selection(Role.DECODE) + return DispatchResult(prefill_uri=prefill_uri, decode_uri=decode_uri) + + def _round_robin_selection(self, role: Role) -> str: + instances = [ + item.uri for i, item in self.instance_infos.items() + if item.role == role and item.uri is not None + ] + if role == Role.PREFILL: + instance = instances[self._round_robin_index_p] + self._round_robin_index_p = (self._round_robin_index_p + + 1) % len(instances) + if role == Role.DECODE: + instance = instances[self._round_robin_index_d] + self._round_robin_index_d = (self._round_robin_index_d + + 1) % len(instances) + if role == Role.MIXED: + instance = instances[self._round_robin_index_m] + self._round_robin_index_m = (self._round_robin_index_m + + 1) % len(instances) + return instance + + def update_vllm_instance_info(self, infos: List[VllmInstanceInfo]): + for item in infos: + self.instance_infos[item.id] = item + self.instance_metrics[item.id] = MetricsInfo() + + # reconstruct the role map + self.role_2_instances.clear() + for _, instance_info in self.instance_infos.items(): + if instance_info.role not in self.role_2_instances: + self.role_2_instances[instance_info.role] = [] + self.role_2_instances[instance_info.role].append(instance_info) + + async def update_vllm_instance_health( + self, vllm_instance_info: List[VllmInstanceInfo]) -> bool: + """ + Update health status of VLLM instances. + + Args: + vllm_instance_info: List of VllmInstanceInfo objects containing information + + Returns: + bool: True if update was successful + """ + + current_time = time.time() + for item in vllm_instance_info: + self.instance_infos[item.id] = item + self.last_heartbeat[item.id] = current_time + return True + + async def is_all_instances_ready(self): + """ + Wait until all VLLM actor running status + + Returns: + No return value. End of function when all actor ready,. 
+ """ + if not self._controller_handle: + try: + self._controller_handle = ray.get_actor( + name=CONTROLLER_ACTOR_NAME, namespace=DLLM_NAMESPACE) + except BaseException: + logger.error('get _controller_handle fail') + _get_expected_vllm_actors_num = await self._controller_handle._get_expected_vllm_actors_num.remote( # type: ignore # ray remote call + ) + while self._get_ready_vllm_actors_num( + ) < _get_expected_vllm_actors_num: + try: + logger.debug( + f"expect {self._get_ready_vllm_actors_num()} waiting vllm actor, " + f"{self.instance_infos}") + for s in self.instance_infos.values(): + if s.status == VllmInstanceStatus.SUBPROCESS_EXITED: + raise RuntimeError( + f"vllm instance: {s} exited unexpectedly") + await asyncio.sleep(1) + except Exception as e: + logger.error( + f"An error when waiting vllm instances ready: {e}") + return + logger.info("All actors are already") + self.all_instances_ready = True + asyncio.create_task(self._monitor_instance_health()) + + def _get_ready_vllm_actors_num(self): + """ + Get the number of ready VLLM instances. + + Returns: + Number of ready VLLM instances. + """ + return sum(info.status == VllmInstanceStatus.RUNNING + for info in self.instance_infos.values()) + + def _get_unready_vllm_actors_num(self): + """ + Get the number of unready VLLM instances. + + Returns: + Number of unready VLLM instances. + """ + return sum(info.status != VllmInstanceStatus.RUNNING + for info in self.instance_infos.values()) + + async def _monitor_instance_health(self): + """ + Monitor instance health, report to controller if >20s no response / failed status + """ + while True: + if self.all_instances_ready: + current_time = time.time() + for info in self.instance_infos.values(): + logger.info( + f"Monitoring ID: {info.id}, Status: {info.status}") + if info.status == VllmInstanceStatus.HEALTHCHECK_FAILED: + logger.error( + f"Instance {info.id} has failed health check.") + self._controller_handle.report_failure_from_balancer.remote( # type: ignore # ray remote call + info.id) + self.all_instances_ready = False + # Consider instance unhealthy if no heartbeat + elif current_time - self.last_heartbeat.get(info.id, + 0) > 20: + logger.error( + f"Instance {info.id} is unhealthy (no heartbeat).") + self._controller_handle.report_failure_from_balancer.remote( # type: ignore # ray remote call + info.id) + self.all_instances_ready = False + else: + logger.info( + "Waiting for all instances ready and restart the health monitoring." 
+ ) + await asyncio.sleep(5) + await asyncio.sleep(1) diff --git a/examples/online_serving/dllm_tool/dllm/config.py b/examples/online_serving/dllm_tool/dllm/config.py new file mode 100644 index 00000000000..0ba29373d4d --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/config.py @@ -0,0 +1,142 @@ +from dataclasses import dataclass +from typing import List, Optional + +from dllm.entities import Role, SchedulerPolicy + +required_vllm_options = [ + ("host", ), + ("port", ), + ("tensor-parallel-size", "tp"), + ("data-parallel-size", "dp"), + ("data-parallel-size-local", "dpl"), + ("data-parallel-start-rank", "dpr"), + ("data-parallel-address", "dpa"), + ("data-parallel-rpc-port", "dpp"), + ("headless", ), + ("enable-expert-parallel", ), + ("disable-expert-parallel", ), + ("kv-transfer-config", ), +] + + +class AutoValidator: + + def __post_init__(self): + for name, f in self.__dataclass_fields__.items( # type: ignore # get all data fields from a dataclass + ): + if method := getattr(self, f"_validate_{name}", None): + method() + + +@dataclass +class InferenceInstanceConfig(AutoValidator): + startup_params: List[str] + startup_env: Optional[str] + tp: int + dp: int + ep: int + + def _validate_startup_params(self): + + def __contain_long_options(opname, params): + underline_op = opname.replace("-", "_") + return any( + p == f"--{opname}" or p.startswith(f"--{opname}=") or p == + f"--{underline_op}" or p.startswith(f"--{underline_op}=") + for p in params) + + def __contain_short_options(opname, params): + underline_op = opname.replace("-", "_") + return any( + p == f"-{opname}" or p.startswith(f"-{opname}=") + or p == f"-{underline_op}" or p.startswith(f"-{underline_op}=") + for p in params) + + bad_options = [] + for option in required_vllm_options: + if len(option) > 0 and __contain_long_options( + option[0], self.startup_params): + bad_options.append(option[0]) + if len(option) > 1 and __contain_short_options( + option[1], self.startup_params): + bad_options.append(option[1]) + + if bad_options: + raise ValueError( + f"{bad_options} should not be specified in start up commands, instead, dllm will populate options after verification" + ) + + def _validate_ep(self): + if self.ep < 0: # type: ignore + raise ValueError( + "expert parallel size should be 0 (EP disabled) or >1 (EP enabled)" + ) + + def _validate_dp(self): + if not self.dp > 0: # type: ignore + raise ValueError("data parallel size should be greater than 0") + + def _validate_tp(self): + if not self.tp > 0: # type: ignore + raise ValueError("tensor parallel size should be greater than 0") + + +@dataclass +class ControllerConfig(AutoValidator): + scheduler_policy: SchedulerPolicy + num_prefill_instances: int + num_decode_instances: int + prefill_instance_config: InferenceInstanceConfig + decode_instance_config: InferenceInstanceConfig + + def _validate_num_prefill_instances(self): + if self.num_prefill_instances < 0: + raise ValueError( + "number of prefill instances should be equal to or greater than 0" + ) + + def _validate_num_decode_instances(self): + if self.num_decode_instances < 0: + raise ValueError( + "number of decode instances should be equal to or greater than 0" + ) + + +@dataclass +class PDDistConfig(AutoValidator): + role: Role + pd_rank: int = 0 + pd_size: int = 0 + + def is_pd_dist(self): + return self.role != Role.MIXED + + +@dataclass +class DPConfig(AutoValidator): + dp_rank: int = 0 + dp_size: int = 1 + dp_local_size: int = 1 + dp_master_ip: str = "" + dp_master_port: int = 0 + + def is_dp_enabled(self): + 
        return self.dp_size and self.dp_size > 1
+
+
+@dataclass
+class EPConfig(AutoValidator):
+    ep_size: int = 0
+
+    def is_ep_enabled(self):
+        return self.ep_size and self.ep_size > 0
+
+
+@dataclass
+class VllmInstanceConfig(AutoValidator):
+    exec_cmd: list[str]
+    env: Optional[str] = None
+    tp: int = 1
+    dp_config: Optional[DPConfig] = None
+    ep_config: Optional[EPConfig] = None
+    pd_config: Optional[PDDistConfig] = None
diff --git a/examples/online_serving/dllm_tool/dllm/constants.py b/examples/online_serving/dllm_tool/dllm/constants.py
new file mode 100644
index 00000000000..8fde92e270e
--- /dev/null
+++ b/examples/online_serving/dllm_tool/dllm/constants.py
@@ -0,0 +1,24 @@
+DLLM_NAMESPACE = "dllm"
+CONTROLLER_ACTOR_NAME = "controller"
+BALANCER_ACTOR_NAME = "balancer"
+
+ENDPOINT_APPLICATION_NAME = "dllm-endpoint"
+ENDPOINT_PROXY_DEPLOYMENT_NAME = "dllm-endpoint"
+
+INSTANCE_HEALTHCHECK_INTERVAL_SEC = 10
+
+HTTP_OK = 200
+
+HTTP_PARAM_INVALID = 400
+
+HTTP_TOO_MANY_REQUESTS = 429
+
+HTTP_INTERNAL_ERROR = 500
+
+NUM_RUNNING_REQUESTS = "vllm:num_requests_running"
+
+NUM_WAITING_REQUESTS = "vllm:num_requests_waiting"
+
+DEVICE_USAGE_PERCENT = "vllm:gpu_cache_usage_perc"
+
+METRICS_UPDATE_CYCLE = 0.5
diff --git a/examples/online_serving/dllm_tool/dllm/controller.py b/examples/online_serving/dllm_tool/dllm/controller.py
new file mode 100644
index 00000000000..a0957aff2d5
--- /dev/null
+++ b/examples/online_serving/dllm_tool/dllm/controller.py
@@ -0,0 +1,291 @@
+import asyncio
+import itertools
+import logging
+from typing import Dict, List, Optional
+
+import ray
+from ray import actor
+
+from dllm.balancer import Balancer
+from dllm.config import (ControllerConfig, DPConfig, EPConfig,
+                         InferenceInstanceConfig, PDDistConfig,
+                         VllmInstanceConfig)
+from dllm.constants import BALANCER_ACTOR_NAME
+from dllm.entities import Role, VllmInstanceInfo
+from dllm.vllm_instance import start_vllm_instance
+
+from vllm.platforms import current_platform
+
+logger = logging.getLogger(__name__)
+
+
+def flatten_list(multi_level_list):
+    return list(itertools.chain(*multi_level_list))
+
+
+def _get_accelerator_num_per_ray_node():
+    accelerator_nums = []
+    for e in ray.nodes():
+        num = e.get("Resources", {}).get(current_platform.device_name, None)
+        if num:
+            accelerator_nums.append(int(num))
+    return max(accelerator_nums)
+
+
+def split_dp_resources(tp_size: int,
+                       dp_size: int,
+                       accelerators_pack_max_size: int = 8) -> List[int]:
+    """
+    pack DP instances into nodes, preventing cross-node DP instances at best effort
+    | DP     | TP     | total  | 910C   | 910B   |
+    | ------ | ------ | ------ | ------ | ------ |
+    | 4      | 2      | 8      | 8      | 8      |
+    | 3      | 3      | 9      | 9      | 6+3    |
+    | 4      | 4      | 16     | 16     | 8+8    |
+    | 32     | 1      | 32     | 16+16  | 8x4    |
+    | 64     | 1      | 64     | 16x4   | 8x8    |
+
+    TODO: optimize resource fragments
+    """
+    assert tp_size <= accelerators_pack_max_size, (
+        f"do not allow TP size to exceed the number of accelerators on a single node {accelerators_pack_max_size}"
+    )
+    total_accelerators = dp_size * tp_size
+    # group_size is the largest multiple of tp_size that fits on one node
+    group_size = (accelerators_pack_max_size -
+                  (accelerators_pack_max_size % tp_size)
+                  if accelerators_pack_max_size % tp_size != 0 else
+                  accelerators_pack_max_size)
+    num_groups = total_accelerators // group_size
+    remainder = total_accelerators % group_size
+    packs = [group_size] * num_groups
+    if remainder > 0:
+        packs.append(remainder)
+    return packs
+
+
+async def make_dp_group(pd_role: Role,
+                        pd_idx: int,
+                        tp_size: int,
+                        dp_size: int,
+                        ep_size: int,
+                        start_params: List[str],
+                        env: Optional[str] = None) -> List[actor.ActorHandle]:
+    """
+    prepare one DP group
+    1. start DP master vllm instance
+       1.1. find DP master ip and a free port as DP master port
+       1.2. init DP master vllm instance's DP config
+    2. start all other DP instances and pass through DP master ip and port
+    """
+    packs = split_dp_resources(
+        tp_size=tp_size,
+        dp_size=dp_size,
+        accelerators_pack_max_size=_get_accelerator_num_per_ray_node())
+    pg = ray.util.placement_group(bundles=[{
+        current_platform.device_name: p
+    } for p in packs],
+                                  strategy="PACK",
+                                  name=f"DP-{pd_role}-{pd_idx}")
+    await pg.ready()
+
+    actors = []
+    dp_master_vllm_instance_config = VllmInstanceConfig(
+        exec_cmd=start_params,
+        env=env,
+        tp=tp_size,
+        pd_config=PDDistConfig(role=pd_role, pd_rank=pd_idx),
+        dp_config=DPConfig(dp_rank=0,
+                           dp_size=dp_size,
+                           dp_local_size=packs[0] // tp_size,
+                           dp_master_ip="",
+                           dp_master_port=0),
+        ep_config=EPConfig(ep_size=ep_size),
+    )
+    dp_master_actor = start_vllm_instance(
+        vllm_instance_config=dp_master_vllm_instance_config, pg=pg)
+    actors.append(dp_master_actor)
+
+    dp_master_ip, dp_master_port = await dp_master_actor.init_dp_master_ip_port.remote(
+    )  # type: ignore # ray remote call
+    dp_master_vllm_instance_config.dp_config.dp_master_ip = dp_master_ip  # type: ignore
+    dp_master_vllm_instance_config.dp_config.dp_master_port = dp_master_port  # type: ignore
+    await dp_master_actor.init_dp_config.remote(
+        dp_master_vllm_instance_config.dp_config
+    )  # type: ignore # ray remote call
+
+    dp_rank = packs[0] // tp_size
+    for idx in range(1, len(packs)):
+        dp_vllm_instance_config = VllmInstanceConfig(
+            exec_cmd=start_params,
+            env=env,
+            tp=tp_size,
+            pd_config=PDDistConfig(role=pd_role, pd_rank=pd_idx),
+            dp_config=DPConfig(
+                dp_rank=dp_rank,
+                dp_size=dp_size,
+                dp_master_ip=dp_master_ip,
+                dp_master_port=dp_master_port,
+                dp_local_size=packs[idx] // tp_size,
+            ),
+            ep_config=EPConfig(ep_size=ep_size),
+        )
+        dp_rank += packs[idx] // tp_size
+        actor = start_vllm_instance(
+            vllm_instance_config=dp_vllm_instance_config, pg=pg)
+        await actor.init_dp_config.remote(dp_vllm_instance_config.dp_config
+                                          )  # type: ignore # ray remote call
+        actors.append(actor)
+    return actors
+
+
+class Controller:
+
+    def __init__(self, controller_config: ControllerConfig):
+        """
+        Initialize the global controller.
+
+        Args:
+            controller_config: ControllerConfig
+        """
+        self.config = controller_config
+
+        self.p_instances_actors: List[actor.ActorHandle] = []
+        self.d_instances_actors: List[actor.ActorHandle] = []
+        self.vllm_instances_info: Dict[str, VllmInstanceInfo] = {}
+        self.balancer = None
+
+    def _get_expected_vllm_actors_num(self):
+        return len(self.p_instances_actors) + len(self.d_instances_actors)
+
+    async def make_inference_instance(
+        self, pd_role: Role, pd_rank: int,
+        inference_instance_config: InferenceInstanceConfig
+    ) -> List[actor.ActorHandle]:
+        """make inference instance (PREFILL instance, or DECODE instance)
+        1. if dp enabled, ==> start dp group
+        2.
if dp not enabled, ==> just start vllm instance + + Returns: + all vllm instances actors in this inference instance + """ + if inference_instance_config.dp and inference_instance_config.dp > 1: + if not inference_instance_config.tp: + inference_instance_config.tp = 1 + # enable dp + return await make_dp_group( + pd_role=pd_role, + pd_idx=pd_rank, + tp_size=inference_instance_config.tp, + dp_size=inference_instance_config.dp, + ep_size=inference_instance_config.ep, + start_params=inference_instance_config.startup_params, + env=inference_instance_config.startup_env, + ) + + # no dp + return [ + start_vllm_instance( + VllmInstanceConfig( + exec_cmd=inference_instance_config.startup_params, + env=inference_instance_config.startup_env, + tp=inference_instance_config.tp, + pd_config=PDDistConfig(role=pd_role, pd_rank=pd_rank), + dp_config=DPConfig(), + ep_config=EPConfig(inference_instance_config.ep), + )) + ] + + async def make_balancer(self) -> List[actor.ActorHandle]: + """make balancer, and send all vllm instance info to the balancer + + Returns: + balancer handle + """ + balancer = ray.remote(Balancer).options( + name=BALANCER_ACTOR_NAME).remote( + self.config.scheduler_policy) # type: ignore # ray remote call + return balancer + + async def initialize(self): + """initialize all vllm instances, construct pd/dp groups""" + # TODO: Need to implement the resource checking logic with Ray + logger.info(f"initialize with config: {self.config}") + # Dictionary to track VLLM instances health status + self.vllm_instances_info: Dict[str, VllmInstanceInfo] = {} # + + # start VllmInstance + # start Prefill Instances + is_disaggregated_pd = self.config.num_prefill_instances > 0 and self.config.num_decode_instances > 0 + for p_pd_rank in range(self.config.num_prefill_instances): + p_actors = self.make_inference_instance( + pd_rank=p_pd_rank, + pd_role=Role.PREFILL if is_disaggregated_pd else Role.MIXED, + inference_instance_config=self.config.prefill_instance_config, + ) + self.p_instances_actors.extend(await p_actors) + + # start Decode Instances + for d_pd_rank in range(self.config.num_decode_instances): + d_actors = self.make_inference_instance( + pd_rank=d_pd_rank, + pd_role=Role.DECODE if is_disaggregated_pd else Role.MIXED, + inference_instance_config=self.config.decode_instance_config, + ) + self.d_instances_actors.extend(await d_actors) + + logger.info("Create Balancer") + self.balancer = await self.make_balancer() + + # init all vllm instances + # TODO how to handle restart for reliability issues + for vllm_instance_actor in [ + *self.p_instances_actors, *self.d_instances_actors + ]: + vllm_instance_actor.initialize.remote() + + # wait for all instances ready + await self.balancer.is_all_instances_ready.remote( # type: ignore # ray remote call + ) + + logger.info( + f"All instances ready, VllmInstance num: {len(self.vllm_instances_info)}, updating Balancer" + ) + + # update Balancer + self.balancer.update_vllm_instance_info.remote( # type: ignore # ray remote call + list(self.vllm_instances_info.values())) + + # TODO start Endpoint actors on each node:: deploy_endpoint_to_cluster() + logger.info( + f"Controller initialized with {self.config.num_prefill_instances} P instances and " + f"{self.config.num_decode_instances} D instances") + + async def terminate(self, timeout_s=5): + """ + TODO: clean all dllm actors started by controller + """ + if self.balancer: + ray.kill(self.balancer) + + terminate_futures = [] + for instance_actor in [ + *self.p_instances_actors, *self.d_instances_actors + ]: + 
terminate_futures.append( + instance_actor.terminate.remote(timeout_s=timeout_s)) + await asyncio.gather(*terminate_futures) + + for instance_actor in [ + *self.p_instances_actors, *self.d_instances_actors + ]: + ray.kill(instance_actor) + + # TODO: Need to implement method to monitor and restart failed instances + + def report_failure_from_balancer(self, instance_id): + """ + Report fail instance from balancer + + Returns: + No Return required + """ + logger.info( + f"Received report from balancer, instance_id is {instance_id} ") + return True diff --git a/examples/online_serving/dllm_tool/dllm/endpoint.py b/examples/online_serving/dllm_tool/dllm/endpoint.py new file mode 100644 index 00000000000..9576fe03522 --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/endpoint.py @@ -0,0 +1,163 @@ +import logging +import os +import uuid +from typing import Union + +import aiohttp +import ray +from fastapi import FastAPI, Request +from fastapi.responses import Response, StreamingResponse +from ray import actor, serve + +from dllm.constants import (BALANCER_ACTOR_NAME, DLLM_NAMESPACE, + ENDPOINT_APPLICATION_NAME, + ENDPOINT_PROXY_DEPLOYMENT_NAME) +from dllm.entities import DispatchResult + +logger = logging.getLogger(__name__) + +app = FastAPI() + + +@serve.deployment( + name=ENDPOINT_PROXY_DEPLOYMENT_NAME, + num_replicas=1, + max_ongoing_requests=4096, +) +@serve.ingress(app) +class ProxyDeployment: + #: the balancer handle + _balancer_handle: Union[actor.ActorHandle, None] + + def __init__(self): + self._balancer_handle = None + + @staticmethod + async def record_exception_info(e): + """ + record exception info + Args: + e: exception info + """ + import sys + import traceback + exc_info = sys.exc_info() + logger.info("Error occurred in disagg prefill proxy server") + logger.info(e) + logger.info("".join(traceback.format_exception(*exc_info))) + + async def forward_request(self, url: str, headers: dict, data: dict): + """ + Send request to the inference instance, return the AsyncGenerator reading the content + Args: + url: request url + headers: request header + data: request data + Returns: + AsyncGenerator: the first iteration is the status code, and subsequent iterations are the response content + """ + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=6 * 60 * 60)) as session: + async with session.post(url=url, json=data, + headers=headers) as response: + # Return status code in advance + yield response.status + if response.status == 200: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + async def forward_request_without_yield(self, url: str, headers: dict, + data: dict): + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=6 * 60 * 60)) as session: + async with session.post(url=url, json=data, + headers=headers) as response: + content = await response.read() + return response.status, content + + async def schedule(self, prompt: str) -> DispatchResult: + if self._balancer_handle is None: + self._balancer_handle = ray.get_actor(name=BALANCER_ACTOR_NAME, + namespace=DLLM_NAMESPACE) + dispatch_result = await self._balancer_handle.dispatch_request.remote( # type: ignore # ray remote call + ) + return dispatch_result + + @app.post("/health") + async def health(self, request: Request): + return Response(status_code=200, content="healthy") + + @app.post("/v1/completions") + async def openai_completions(self, raw_request: Request): + """ + 
https://github.com/vllm-project/vllm/blob/main/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py + """ + import pydantic + from vllm.entrypoints.openai.protocol import CompletionRequest + + request_body = await raw_request.json() + headers = { + "Authorization": + f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": + raw_request.headers.get("X-Request-Id") or str(uuid.uuid4()) + } + + try: + request = CompletionRequest(**request_body) + except pydantic.ValidationError as e: + return Response(status_code=500, content={"error": str(e)}) + + assert isinstance(request.prompt, + str), "currently only support one prompt at a time" + + dispatch_result = await self.schedule(request.prompt) + logger.info( + f"({headers['X-Request-Id']}) recv request: {request.prompt}, " + f"prefill to: {dispatch_result.prefill_uri}," + f"decode to {dispatch_result.decode_uri}") + + try: + prefill_request = request_body.copy() + prefill_request["max_tokens"] = 1 + if dispatch_result.prefill_uri: + status_code, prefill_result = await self.forward_request_without_yield( + f"{dispatch_result.prefill_uri}/v1/completions", + headers=headers, + data=prefill_request, + ) + if status_code != 200: + logger.error( + f"prefill request failed, status code:{status_code}, content:{prefill_result}" + ) + return Response(content=prefill_result, + status_code=status_code) + + # return decode + decode_token_generator = self.forward_request( + f"{dispatch_result.decode_uri}/v1/completions", + headers=headers, + data=request_body, + ) + status_code = 200 + # Only iterate once, get the status code and transmit it transparently + async for status in decode_token_generator: + status_code = status + break + return StreamingResponse( + decode_token_generator, # type: ignore + status_code=status_code, # type: ignore + media_type="application/octet-stream", + ) + except Exception as e: + await self.record_exception_info(e) + raise + + +def deploy_endpoint_to_cluster(host: str = "0.0.0.0", port: int = 8000): + serve.start(http_options=serve.HTTPOptions(host=host, port=port)) + serve.run(ProxyDeployment.bind(), name=ENDPOINT_APPLICATION_NAME) diff --git a/examples/online_serving/dllm_tool/dllm/entities.py b/examples/online_serving/dllm_tool/dllm/entities.py new file mode 100644 index 00000000000..26c6d129a49 --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/entities.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import ClassVar, Optional + +from dllm import constants + + +class VllmInstanceStatus(Enum): + UNREADY = auto() + RUNNING = auto() + SUBPROCESS_EXITED = auto() + HEALTHCHECK_FAILED = auto() + + +class Role(Enum): + PREFILL = 0 + DECODE = 1 + MIXED = 2 + + +class SchedulerPolicy(Enum): + ROUND_ROBIN = 0 + + +@dataclass +class VllmInstanceInfo: + id: str + uri: str + role: Role + status: VllmInstanceStatus = VllmInstanceStatus.UNREADY + dp_master_ip: str = "" + dp_master_port: int = 0 + + +@dataclass +class DispatchResult: + prefill_uri: Optional[str] + decode_uri: Optional[str] + + +@dataclass +class MetricsInfo: + num_running_requests: int = 0 + num_waiting_requests: int = 0 + device_usage_percent: float = 0.0 + + METRIC_NAME_MAPPING: ClassVar[dict] = { + "num_running_requests": constants.NUM_RUNNING_REQUESTS, + "num_waiting_requests": constants.NUM_WAITING_REQUESTS, + "device_usage_percent": constants.DEVICE_USAGE_PERCENT, + } \ No newline at end of file diff --git a/examples/online_serving/dllm_tool/dllm/logging.py 
b/examples/online_serving/dllm_tool/dllm/logging.py new file mode 100644 index 00000000000..073c2bbd0fe --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/logging.py @@ -0,0 +1,11 @@ +import logging + +def setup_logging(level=logging.INFO): + logger = logging.getLogger("dllm") + logger.propagate = False + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)s] %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) diff --git a/examples/online_serving/dllm_tool/dllm/scripts.py b/examples/online_serving/dllm_tool/dllm/scripts.py new file mode 100644 index 00000000000..46f0dcedd0b --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/scripts.py @@ -0,0 +1,235 @@ +from typing import List, Optional +import click +import ray +import logging +import shlex + +from dllm.controller import Controller +from dllm.endpoint import deploy_endpoint_to_cluster +from dllm.logging import setup_logging +from dllm.constants import ENDPOINT_APPLICATION_NAME, DLLM_NAMESPACE, CONTROLLER_ACTOR_NAME +from dllm.entities import SchedulerPolicy +from dllm.config import ControllerConfig, InferenceInstanceConfig + +setup_logging() +logger = logging.getLogger(__name__) + +@click.group() +def cli(): + """DLLM Cluster Management""" + pass + + +@cli.command(name="deploy", context_settings={"show_default": True}) +@click.option("--head-ip", type=str, help='IP of Ray head node (e.g. "10.2.3.4")', default="auto") +@click.option("--prefill-instances-num", type=int, help="the num of Prefill instances", default=0) +@click.option( + "--prefill-startup-params", + type=str, + help="the Prefill instance start up command", + default="vllm serve /workspace/models/qwen2.5_7B", + callback=lambda ctx, param, value: shlex.split(value), +) +@click.option( + "--prefill-startup-env", + type=str, + help="the Prefill instance start up env", + default=None, +) +@click.option("--prefill-data-parallel-size", "-pdp", type=int, help="the dp of Prefill instances", default=1) +@click.option("--prefill-tensor-parallel-size", "-ptp", type=int, help="the tp of Prefill instances", default=1) +@click.option( + "--prefill-expert-parallel-size", + "-pep", + type=int, + help="the ep of Prefill instances, should be equal to dp*tp, 0 means disable expert parallelism", + default=0, +) +@click.option("--decode-instances-num", type=int, help="the num of Decode instances", default=0) +@click.option( + "--decode-startup-params", + type=str, + help="the Decode instance start up command", + default="vllm serve /workspace/models/qwen2.5_7B", + callback=lambda ctx, param, value: shlex.split(value), +) +@click.option( + "--decode-startup-env", + type=str, + help="the decode instance start up env", + default=None, +) +@click.option("--decode-data-parallel-size", "-ddp", type=int, help="the dp of Decode instances", default=1) +@click.option("--decode-tensor-parallel-size", "-dtp", type=int, help="the tp of Decode instances", default=1) +@click.option( + "--decode-expert-parallel-size", + "-dep", + type=int, + help="the ep of Decode instances, should be equal to dp*tp, 0 means disable expert parallelism", + default=0, +) +@click.option( + "--scheduler-policy", + type=click.Choice([e.name for e in SchedulerPolicy], case_sensitive=False), + help="the scheduling policy, default to RoundRobin", + default=SchedulerPolicy.ROUND_ROBIN.name, + callback=lambda ctx, param, value: SchedulerPolicy[value.upper()], +) 
+@click.option("--proxy-host", type=str, help="the dllm service listening host", default="0.0.0.0") +@click.option("--proxy-port", type=int, help="the dllm service listening port", default=8000) +def deploy( + head_ip: str, + prefill_instances_num: int, + prefill_startup_params: List[str], + prefill_startup_env: Optional[str], + prefill_data_parallel_size: int, + prefill_tensor_parallel_size: int, + prefill_expert_parallel_size: int, + decode_instances_num: int, + decode_startup_params: List[str], + decode_startup_env: Optional[str], + decode_data_parallel_size: int, + decode_tensor_parallel_size: int, + decode_expert_parallel_size: int, + scheduler_policy: SchedulerPolicy, + proxy_host: str, + proxy_port: int, +): + _inner_deploy( + head_ip, + prefill_instances_num, + prefill_startup_params, + prefill_startup_env, + prefill_data_parallel_size, + prefill_tensor_parallel_size, + prefill_expert_parallel_size, + decode_instances_num, + decode_startup_params, + decode_startup_env, + decode_data_parallel_size, + decode_tensor_parallel_size, + decode_expert_parallel_size, + scheduler_policy, + proxy_host, + proxy_port, + ) + + +def _inner_deploy( + head_ip: str, + prefill_instances_num: int, + prefill_startup_params: List[str], + prefill_startup_env: Optional[str], + prefill_data_parallel_size: int, + prefill_tensor_parallel_size: int, + prefill_expert_parallel_size: int, + decode_instances_num: int, + decode_startup_params: List[str], + decode_startup_env: Optional[str], + decode_data_parallel_size: int, + decode_tensor_parallel_size: int, + decode_expert_parallel_size: int, + scheduler_policy: SchedulerPolicy, + proxy_host: str, + proxy_port: int, +): + config = ControllerConfig( + scheduler_policy=scheduler_policy, + num_prefill_instances=prefill_instances_num, + prefill_instance_config=InferenceInstanceConfig( + startup_params=prefill_startup_params, + startup_env=prefill_startup_env, + dp=prefill_data_parallel_size, + tp=prefill_tensor_parallel_size, + ep=prefill_expert_parallel_size, + ), + num_decode_instances=decode_instances_num, + decode_instance_config=InferenceInstanceConfig( + startup_params=decode_startup_params, + startup_env=decode_startup_env, + dp=decode_data_parallel_size, + tp=decode_tensor_parallel_size, + ep=decode_expert_parallel_size, + ), + ) + + """Deploy to Ray cluster""" + try: + logger.info("Connecting to existing Ray cluster at: %s", head_ip) + ray.init(address=head_ip, namespace=DLLM_NAMESPACE, + runtime_env={"worker_process_setup_hook": setup_logging}) + except Exception as e: + logger.exception("Failed to connect ray cluster: %s", str(e)) + return + + logger.info("Ray cluster resources: %s", ray.cluster_resources()) + + should_start_controller = False + try: + controller = ray.get_actor(CONTROLLER_ACTOR_NAME) + logger.exception( + "There is already an dllm controller running in the cluster, please clean dllm before " "deploy again" + ) + except ValueError: + should_start_controller = True + + if not should_start_controller: + return + + logger.info("No existing Controller found, creating new instance") + controller = ray.remote(Controller).options( + name=CONTROLLER_ACTOR_NAME, + lifetime="detached", + ).remote(config) + ray.get(controller.initialize.remote()) + logger.info("Controller actor created.") + + try: + ray.serve.shutdown() + deploy_endpoint_to_cluster(proxy_host, proxy_port) + logger.info("Deployment completed successfully") + except Exception as e: + logger.exception("Deployment failed: %s", str(e)) + + +@cli.command("clean", 
context_settings={"show_default": True}) +@click.option("--head-ip", type=str, help='IP of Ray head node (e.g. "10.2.3.4")', default="auto") +@click.option("--shutdown-ray-serve/--no-shutdown-ray-serve", type=bool, is_flag=True, + help="whether or not to shutdown Ray serve proxy", default=True) +def clean(head_ip, shutdown_ray_serve): + """Clean up deployment from Ray cluster""" + _inner_clean(head_ip, shutdown_ray_serve) + + +def _inner_clean(head_ip, shutdown_ray_serve): + try: + logger.info("Connecting to existing Ray cluster at: %s", head_ip) + ray.init(address=head_ip, namespace=DLLM_NAMESPACE, log_to_driver=False, + runtime_env={"worker_process_setup_hook": setup_logging}) + except Exception as e: + logger.exception("Failed to connect ray cluster: %s", str(e)) + return + + if shutdown_ray_serve: + ray.serve.shutdown() + else: + try: + ray.serve.delete(ENDPOINT_APPLICATION_NAME) + except Exception as e: + logger.warning("Cleanup endpoint failed: %s", str(e)) + + controller = None + try: + controller = ray.get_actor(CONTROLLER_ACTOR_NAME) + logger.info("Found existing Controller actor, attempting to kill it") + ray.get(controller.terminate.remote()) + except ValueError: + logger.info("No existing Controller actor found, nothing to clean") + except Exception as e: + logger.info(f"Failed to clean up controller {e}") + finally: + if controller: + ray.kill(controller) + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/examples/online_serving/dllm_tool/dllm/utils.py b/examples/online_serving/dllm_tool/dllm/utils.py new file mode 100644 index 00000000000..0c6c39f8cd8 --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/utils.py @@ -0,0 +1,118 @@ +import errno +import glob +import socket + +import psutil +import ray + + +def ray_run_on_every_nodes(func, *args, **kwargs): + unique_ips = set( + [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]) + futures = [ + ray.remote(func).options(resources={ + f"node:{ip}": 0.01 + }).remote(*args, **kwargs) for ip in unique_ips + ] + return ray.get(futures) + + + +def find_node_ip(address: str = "8.8.8.8:53") -> str: + """ + NOTE: this implementation is adapted from ray-project/ray, see: + https://github.com/ray-project/ray/blob/aa2dede7f795d21407deebf4cefc61fd00e68e84/python/ray/_private/services.py#L637 + + IP address by which the local node can be reached *from* the `address`. + + Args: + address: The IP address and port of any known live service on the + network you care about. + + Returns: + The IP address by which the local node can be reached from the address. + """ + ip_address, port = address.split(":") + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # This command will raise an exception if there is no internet + # connection. 
+ s.connect((ip_address, int(port))) + node_ip_address = s.getsockname()[0] + except OSError as e: + node_ip_address = "127.0.0.1" + # [Errno 101] Network is unreachable + if e.errno == errno.ENETUNREACH: + try: + # try get node ip address from host name + host_name = socket.getfqdn(socket.gethostname()) + node_ip_address = socket.gethostbyname(host_name) + except Exception: + pass + finally: + s.close() + + return node_ip_address + + +def find_free_port(address: str = "") -> str: + """ + find one free port + + Returns: + port + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((address, 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return str(s.getsockname()[1]) + + +def find_interface_by_ip(ip_address): + """ + Find the network interface name associated with the given IP address. + + Args: + ip_address (str): The IP address to look up (e.g., "192.168.1.100"). + + Returns: + str: The name of the matching network interface (e.g., "eth0" or "wlan0"), or None if not found. + """ + interfaces = psutil.net_if_addrs() + + for interface_name, addresses in interfaces.items(): + for address in addresses: + if address.family == socket.AF_INET and address.address == ip_address: + return interface_name + + # Return None if no match is found + return None + + +def find_ip_by_interface(interface_name: str): + """ + Find the IP address associated with the given network interface name. + + Args: + interface_name (str): The name of the network interface (e.g., "eth0", "wlan0"). + + Returns: + str: The IP address associated with the interface, or None if not found. + """ + # Get all network interfaces and their addresses + interfaces = psutil.net_if_addrs() + + # Check if the interface exists + if interface_name not in interfaces: + return None + + # Determine the address family (IPv4 or IPv6) + family = socket.AF_INET # IPv6: 10 (AF_INET6), IPv4: 2 (AF_INET) + + # Iterate through the addresses of the specified interface + for address in interfaces[interface_name]: + if address.family == family: + return address.address + + # Return None if no matching IP address is found + return None diff --git a/examples/online_serving/dllm_tool/dllm/vllm_instance.py b/examples/online_serving/dllm_tool/dllm/vllm_instance.py new file mode 100644 index 00000000000..2792c93a7f6 --- /dev/null +++ b/examples/online_serving/dllm_tool/dllm/vllm_instance.py @@ -0,0 +1,295 @@ +import asyncio +import json +import logging +import os +import signal +import subprocess +import sys +from asyncio import Task +from typing import Optional + +import aiohttp +import ray +from ray import actor +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from dllm.config import DPConfig, EPConfig, VllmInstanceConfig +from dllm.constants import (BALANCER_ACTOR_NAME, DLLM_NAMESPACE, + INSTANCE_HEALTHCHECK_INTERVAL_SEC) +from dllm.entities import Role, VllmInstanceInfo, VllmInstanceStatus +from dllm.utils import (find_free_port, find_interface_by_ip, + find_ip_by_interface, find_node_ip) + +from vllm.platforms import current_platform + +logger = logging.getLogger(__name__) + + +def select_distributed_torch_interface(): + for env in ["GLOO_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"]: + if env in os.environ: + return os.environ[env] + + +class VllmInstance: + """ + VllmInstance is a vllm engine wrapped by a ray actor, responsibilities: + 1. 
start vllm api server (and pass some args) + ref: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#vllm-serve + 2. do the health check job (report to Controller if any failure) + """ + + _vllm_instance_config: VllmInstanceConfig + _vllm_instance_info: VllmInstanceInfo + #: the actor handle of balancer + _balancer_handle: Optional[actor.ActorHandle] + _vllm_api_server_process: Optional[subprocess.Popen] + _vllm_api_server_health_monitor_task: Optional[Task[None]] + + def __init__(self, name: str, vllm_config: VllmInstanceConfig): + """ + Args: + env: the environment variables pass to subprocess + exec_cmd: the vllm api server startup command, e.g. ["vllm", "serve", "--a=1", "--b=2"] + """ + assert vllm_config.pd_config is not None, "vllm instance PD config is None, abort" + self._vllm_instance_config = vllm_config + self._vllm_instance_info = VllmInstanceInfo( + id=name, uri="", role=vllm_config.pd_config.role) + self._balancer_handle = None + self._vllm_api_server_process = None + self._vllm_api_server_health_monitor_task = None + self._env = dict(os.environ) + self._env["HCCL_IF_BASE_PORT"] = os.environ.get( + 'HCCL_IF_BASE_PORT', "50000") + + self.__has_process_started = False + + async def init_dp_master_ip_port(self): + """ + if dp config is None, init dp master + """ + intf = select_distributed_torch_interface() + if intf: + ip = find_ip_by_interface(intf) + else: + ip = find_node_ip() + intf = find_interface_by_ip(ip) + assert intf is not None and ip is not None, "failed to find an available network interface for DP group communication, set env GLOO_SOCKET_IFNAME or NCCL_SOCKET_IFNAME manually and try again" + self._env["GLOO_SOCKET_IFNAME"] = intf + self._env["NCCL_SOCKET_IFNAME"] = intf + master_port = find_free_port(ip) + return ip, master_port + + async def initialize(self) -> None: + """launch subprocess""" + logger.info( + f"initialize with ASCEND_RT_VISIBLE_DEVICES: {os.environ.get('ASCEND_RT_VISIBLE_DEVICES')}" + ) + # normalize and set some env vars + self._resort_ascend_rt_visible_devices_env() + self._env["VLLM_USE_V1"] = "1" + + # init all None configs + if self._vllm_instance_config.dp_config is None: + self._vllm_instance_config.dp_config = DPConfig() + if self._vllm_instance_config.ep_config is None: + self._vllm_instance_config.ep_config = EPConfig() + + # api server options + # dp slaves have no http api server + if self._vllm_instance_config.dp_config.dp_size == 0 or self._vllm_instance_config.dp_config.dp_rank == 0: + protocol = "http" + ip = find_node_ip() + port = find_free_port() + self._vllm_instance_info.uri = f"{protocol}://{ip}:{port}" + self._vllm_instance_config.exec_cmd.extend( + ["--host", ip, "--port", str(port)]) + + # tp, pd, and dp options + self._vllm_instance_config.exec_cmd.extend( + ["--tensor-parallel-size", + str(self._vllm_instance_config.tp)]) + self._add_dp_command_options() + self._add_ep_command_options() + self._add_env() + + logger.info( + f"initialize with command: {self._vllm_instance_config.exec_cmd}, env:{self._env}" + ) + self._vllm_api_server_process = subprocess.Popen( + self._vllm_instance_config.exec_cmd, + stdout=sys.stdout, + stdin=sys.stdin, + stderr=sys.stderr, + text=True, + preexec_fn=os.setpgrp, + env=self._env, + ) + + # use a thread to check and report health status + # thread safety issue: https://github.com/ray-project/ray/issues/2385 + self._vllm_api_server_health_monitor_task = asyncio.create_task( + self._monitor_health()) + + def _resort_ascend_rt_visible_devices_env(self): + if 
"ASCEND_RT_VISIBLE_DEVICES" not in os.environ.keys(): + return + try: + device_ids = [ + int(id.strip()) + for id in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",") + ] + except ValueError: + return + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join( + map(str, sorted(device_ids))) + self._env["ASCEND_RT_VISIBLE_DEVICES"] = ",".join( + map(str, sorted(device_ids))) + + def _add_dp_command_options(self): + assert self._vllm_instance_config.dp_config is not None, "vllm instance DP config is None, abort" + if not self._vllm_instance_config.dp_config.is_dp_enabled(): + return + + self._vllm_instance_config.exec_cmd.extend([ + "--data-parallel-size", + str(self._vllm_instance_config.dp_config.dp_size), + "--data-parallel-size-local", + str(self._vllm_instance_config.dp_config.dp_local_size), + "--data-parallel-start-rank", + str(self._vllm_instance_config.dp_config.dp_rank), + "--data-parallel-address", + str(self._vllm_instance_config.dp_config.dp_master_ip), + "--data-parallel-rpc-port", + str(self._vllm_instance_config.dp_config.dp_master_port), + ]) + dp_config = self._vllm_instance_config.dp_config + if dp_config and dp_config.dp_rank > 0: + self._vllm_instance_config.exec_cmd.extend(["--headless"]) + + def _add_ep_command_options(self): + assert self._vllm_instance_config.ep_config is not None, "vllm instance EP config is None, abort" + if not self._vllm_instance_config.ep_config.is_ep_enabled(): + return + + self._vllm_instance_config.exec_cmd.extend([ + "--enable-expert-parallel", + ]) + + def _add_env(self): + if self._vllm_instance_config.env is None: + return + + env_dict = dict( + item.split('=') for item in self._vllm_instance_config.env.split()) + for env_key, env_value in env_dict.items(): + self._env[env_key] = env_value + + async def _monitor_health(self): + """Asynchronously monitor subprocess health and report to controller""" + while not self._balancer_handle: + try: + self._balancer_handle = ray.get_actor(name=BALANCER_ACTOR_NAME, + namespace=DLLM_NAMESPACE) + except Exception: + logger.warning( + 'Instance get _balancer_handle failed, wait for 1 second and retry.' 
+ ) + await asyncio.sleep(1) + + async with aiohttp.ClientSession() as session: + last_report_time = asyncio.get_event_loop().time() + last_status = self._vllm_instance_info.status + while True: + self._vllm_instance_info.status = VllmInstanceStatus.RUNNING + assert self._vllm_api_server_process is not None, "vllm api server process is not started" + if self._vllm_api_server_process.poll() is not None: + self._vllm_instance_info.status = VllmInstanceStatus.SUBPROCESS_EXITED + elif self._vllm_instance_info.uri is not None: # only check DP master's healthy + try: + async with session.get( + f"{self._vllm_instance_info.uri}/health", + timeout=aiohttp.ClientTimeout( + total=2)) as response: + self._vllm_instance_info.status = ( + VllmInstanceStatus.HEALTHCHECK_FAILED + if response.status != 200 else + VllmInstanceStatus.RUNNING) + except (aiohttp.ClientError, asyncio.TimeoutError): + self._vllm_instance_info.status = VllmInstanceStatus.HEALTHCHECK_FAILED + if ( + # not healthy + self._vllm_instance_info.status != + VllmInstanceStatus.RUNNING + # or changed + or self._vllm_instance_info.status != last_status + # or past quite long time, we should let controller know that we are still alive + or asyncio.get_event_loop().time() - last_report_time > + INSTANCE_HEALTHCHECK_INTERVAL_SEC): + await self._balancer_handle.update_vllm_instance_health.remote( + [self._vllm_instance_info + ]) # type: ignore # ray remote call + last_report_time = asyncio.get_event_loop().time() + last_status = self._vllm_instance_info.status + + if self._vllm_instance_info.status == VllmInstanceStatus.SUBPROCESS_EXITED: + # terminate self + logger.info( + "vllm subprocess exited unexpectedly, VllmInstance exit with vllm together" + ) + await asyncio.sleep(5) + + async def terminate(self, timeout_s=5): + if self._vllm_api_server_process is None: + return + + try: + pgid = os.getpgid(self._vllm_api_server_process.pid) + os.killpg(pgid, signal.SIGTERM) + except ProcessLookupError: + logger.info("process already exited") + return + + # Another way is "self._vllm_api_server_process.terminate()" + try: + self._vllm_api_server_process.wait(timeout_s) + except (TimeoutError, subprocess.TimeoutExpired): + pass + finally: + if self._vllm_api_server_process.poll() is None: + # Another way is "self._vllm_api_server_process.kill()" + os.killpg(pgid, signal.SIGKILL) + + +def start_vllm_instance( + vllm_instance_config: VllmInstanceConfig, + pg: Optional[PlacementGroup] = None) -> actor.ActorHandle: + assert vllm_instance_config.pd_config is not None, "vllm instance PD config is None, abort" + name = f"vllm-instance-{vllm_instance_config.pd_config.role.name}-{vllm_instance_config.pd_config.pd_rank}" + assert vllm_instance_config.dp_config is not None, "vllm instance DP config is None, abort" + if vllm_instance_config.dp_config.dp_size > 1: + # DP env should be set by `init_dp_config` method + name = ( + f"{name}-DP-{vllm_instance_config.dp_config.dp_rank}-" + f"{vllm_instance_config.dp_config.dp_rank+vllm_instance_config.dp_config.dp_local_size}" + ) + + actor_options = { + "resources": { + current_platform.device_name: + vllm_instance_config.dp_config.dp_local_size * + vllm_instance_config.tp + }, + "name": name, + "num_cpus": 0, + } + if pg: + actor_options[ + "scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=pg, ) + + vllm_instance_actor = ray.remote(VllmInstance).options( + **actor_options).remote(name, vllm_instance_config) + return vllm_instance_actor # type: ignore # ray actor handle diff --git 
a/examples/online_serving/dllm_tool/pyproject.toml b/examples/online_serving/dllm_tool/pyproject.toml new file mode 100644 index 00000000000..c912b621920 --- /dev/null +++ b/examples/online_serving/dllm_tool/pyproject.toml @@ -0,0 +1,46 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel", "pybind11==2.10.3", "cmake"] +build-backend = "setuptools.build_meta" + +[project] +name = "dllm" +version = "0.1" +description = "cluster deployment tool for vllm" + +requires-python = ">=3.10" +dependencies = [ + "requests>=2.25.1", + "numpy>=1.19.2", + "fastapi", + "aiohttp", + "uvicorn", + "ray[serve]", + "click", + "psutil", +] +classifiers = [ + "Programming Language :: Python :: 3.10", + "Operating System :: OS Independent", +] + +[project.scripts] +dllm = "dllm.scripts:cli" + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "black>=23.0.0", + "sphinx", + "sphinx-design", + "myst-parser", + "sphinx-click", +] +build = [ + "wheel>=0.10.0", + "twine>=4.0.0", + "build>=0.10.0", +] + +[tool.black] +line-length = 120 +target-version = ["py310", "py311"] \ No newline at end of file
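
Usage note (not part of the patch): a minimal sketch of how the pieces above can be driven from Python instead of the `dllm` CLI, using `_inner_deploy` and `SchedulerPolicy` as defined in dllm/scripts.py and dllm/entities.py. It assumes a Ray cluster is already running; the model path, instance counts, and parallel sizes below are illustrative placeholders, not values prescribed by this patch.

    # hypothetical driver script; adjust model path and parallel sizes to your cluster
    from dllm.entities import SchedulerPolicy
    from dllm.scripts import _inner_deploy

    _inner_deploy(
        head_ip="auto",  # connect to the existing Ray cluster
        prefill_instances_num=1,
        prefill_startup_params=["vllm", "serve", "/workspace/models/qwen2.5_7B"],
        prefill_startup_env=None,
        prefill_data_parallel_size=1,
        prefill_tensor_parallel_size=2,
        prefill_expert_parallel_size=0,  # 0 disables expert parallelism
        decode_instances_num=1,
        decode_startup_params=["vllm", "serve", "/workspace/models/qwen2.5_7B"],
        decode_startup_env=None,
        decode_data_parallel_size=1,
        decode_tensor_parallel_size=2,
        decode_expert_parallel_size=0,
        scheduler_policy=SchedulerPolicy.ROUND_ROBIN,
        proxy_host="0.0.0.0",
        proxy_port=8000,  # the Ray Serve proxy then exposes /v1/completions here
    )

The equivalent `dllm deploy ...` / `dllm clean` CLI entry points wrap this same function, so the keyword arguments mirror the click options one to one.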