Skip to content

A developer friendly tool for multi-instance deployment with Ray Implementation #20761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
244 changes: 244 additions & 0 deletions examples/online_serving/dllm_tool/dllm/balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import asyncio
import logging
import time
from typing import Dict, List, Optional

Check failure on line 4 in examples/online_serving/dllm_tool/dllm/balancer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (UP035)

examples/online_serving/dllm_tool/dllm/balancer.py:4:1: UP035 `typing.List` is deprecated, use `list` instead

Check failure on line 4 in examples/online_serving/dllm_tool/dllm/balancer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (UP035)

examples/online_serving/dllm_tool/dllm/balancer.py:4:1: UP035 `typing.Dict` is deprecated, use `dict` instead

import aiohttp
import ray
from prometheus_client.parser import text_string_to_metric_families
from ray import actor

from dllm import constants
from dllm.constants import CONTROLLER_ACTOR_NAME, DLLM_NAMESPACE
from dllm.entities import (DispatchResult, MetricsInfo, Role, SchedulerPolicy,
VllmInstanceInfo, VllmInstanceStatus)

logger = logging.getLogger(__name__)


class Balancer:
    """Dispatches requests across vLLM instances and monitors their health.

    The balancer keeps a registry of known vLLM instances grouped by role
    (prefill / decode / mixed), periodically scrapes their Prometheus
    ``/metrics`` endpoints, and hands out target URIs for new requests
    according to the configured scheduling policy (currently round-robin
    only). Once every expected instance is ready it also starts a
    heartbeat-based health monitor that reports failures to the controller
    actor.
    """

    def __init__(
        self,
        policy: SchedulerPolicy = SchedulerPolicy.ROUND_ROBIN,
    ):
        self.policy = policy
        # role (prefill/decode/mixed) -> instances serving that role
        self.role_2_instances: dict[Role, list[VllmInstanceInfo]] = {}
        # instance id -> latest VllmInstanceInfo
        self.instance_infos: dict[str, VllmInstanceInfo] = {}
        # instance id -> most recently scraped MetricsInfo
        self.instance_metrics: dict[str, MetricsInfo] = {}
        # Independent round-robin cursors, one per role.
        self._round_robin_index_p = 0
        self._round_robin_index_d = 0
        self._round_robin_index_m = 0
        # instance id -> wall-clock time of the last heartbeat received
        self.last_heartbeat: dict[str, float] = {}
        self._controller_handle: Optional[actor.ActorHandle] = None
        self.all_instances_ready = False

        # Start the background metrics-polling loop on the current event loop.
        loop = asyncio.get_event_loop()
        loop.create_task(self.update_vllm_instance_metrics())

    async def update_vllm_instance_metrics(self):
        """Background task: scrape /metrics from every known instance forever."""
        while True:
            try:
                async with aiohttp.ClientSession() as session:
                    await asyncio.gather(
                        *[
                            self._query_instance_metrics(session, info)
                            for info in self.instance_infos.values()
                            if info.uri is not None
                        ],
                        # Never let one failed scrape cancel the others.
                        return_exceptions=True,
                    )
            except Exception as e:
                logger.error("create request session error: %s", e)
            # Sleep outside the try block so a persistent error cannot turn
            # this loop into a busy spin.
            await asyncio.sleep(constants.METRICS_UPDATE_CYCLE)

    def dispatch_request(self) -> DispatchResult:
        """Pick target instance URIs for one request under the current policy.

        Returns:
            DispatchResult with the prefill/decode URIs to use.

        Raises:
            ValueError: if the configured policy is not supported.
        """
        if self.policy == SchedulerPolicy.ROUND_ROBIN:
            return self._round_robin_pair()
        raise ValueError(f"Unsupported policy: {self.policy}")

    def get_all_instance(self) -> dict[str, VllmInstanceInfo]:
        """Return all vllm instances keyed by instance id."""
        return self.instance_infos

    async def _query_instance_metrics(self, session, instance_info):
        """Scrape one instance's /metrics endpoint and update its MetricsInfo."""
        ins_uri = instance_info.uri
        ins_id = instance_info.id
        # NOTE(review): /metrics is queried with POST; Prometheus endpoints
        # conventionally serve GET -- confirm the vLLM server accepts POST.
        async with session.post(f"{ins_uri}/metrics", timeout=3) as resp:
            resp_code = resp.status
            if resp_code != constants.HTTP_OK:
                logger.error("get metrics failed, uri:%s, code:%s", ins_uri,
                             resp_code)
                return
            resp_body = await resp.text()
        # Keep only the metric families we track: {metric_name: first sample}.
        metrics_dict = {
            metric_family.name: metric_family.samples[0].value
            for metric_family in text_string_to_metric_families(resp_body)
            if metric_family.name in MetricsInfo.METRIC_NAME_MAPPING.values()
            and metric_family.samples
        }
        if not metrics_dict:
            return
        if ins_id not in self.instance_metrics:
            self.instance_metrics[ins_id] = MetricsInfo()
        metric_info = self.instance_metrics[ins_id]
        for param_name, metric_name in MetricsInfo.METRIC_NAME_MAPPING.items():
            if metric_name not in metrics_dict:
                continue
            # Convert the raw Prometheus sample to the field's annotated type.
            target_type = metric_info.__annotations__[param_name]
            setattr(metric_info, param_name,
                    target_type(metrics_dict[metric_name]))
        logger.debug("instance metrics info: %s", self.instance_metrics)

    def _round_robin_pair(self) -> DispatchResult:
        """Select a (prefill, decode) pair; prefers mixed instances if any."""
        # Current policy: if any mixed instance exists, use mixed only.
        if self.role_2_instances.get(Role.MIXED):
            mixed_uri = self._round_robin_selection(Role.MIXED)
            return DispatchResult(prefill_uri=None, decode_uri=mixed_uri)

        prefill_uri = self._round_robin_selection(Role.PREFILL)
        decode_uri = self._round_robin_selection(Role.DECODE)
        return DispatchResult(prefill_uri=prefill_uri, decode_uri=decode_uri)

    def _round_robin_selection(self, role: Role) -> str:
        """Pick the next URI for *role*, advancing that role's cursor.

        Raises:
            RuntimeError: if no instance of *role* has a usable URI
                (previously this crashed with IndexError/UnboundLocalError).
        """
        uris = [
            info.uri for info in self.instance_infos.values()
            if info.role == role and info.uri is not None
        ]
        if not uris:
            raise RuntimeError(f"no available vllm instance for role {role}")
        if role == Role.PREFILL:
            # Clamp the cursor first: the instance set may have shrunk
            # since the previous call.
            self._round_robin_index_p %= len(uris)
            uri = uris[self._round_robin_index_p]
            self._round_robin_index_p = (self._round_robin_index_p +
                                         1) % len(uris)
        elif role == Role.DECODE:
            self._round_robin_index_d %= len(uris)
            uri = uris[self._round_robin_index_d]
            self._round_robin_index_d = (self._round_robin_index_d +
                                         1) % len(uris)
        else:  # Role.MIXED
            self._round_robin_index_m %= len(uris)
            uri = uris[self._round_robin_index_m]
            self._round_robin_index_m = (self._round_robin_index_m +
                                         1) % len(uris)
        return uri

    def update_vllm_instance_info(self, infos: list[VllmInstanceInfo]):
        """Register or refresh instance infos and rebuild the role map."""
        for item in infos:
            self.instance_infos[item.id] = item
            # Reset metrics; the polling loop will repopulate them.
            self.instance_metrics[item.id] = MetricsInfo()

        # Reconstruct the role -> instances map from scratch.
        self.role_2_instances.clear()
        for instance_info in self.instance_infos.values():
            self.role_2_instances.setdefault(instance_info.role,
                                             []).append(instance_info)

    async def update_vllm_instance_health(
            self, vllm_instance_info: list[VllmInstanceInfo]) -> bool:
        """Update health status of VLLM instances.

        Args:
            vllm_instance_info: List of VllmInstanceInfo objects containing
                information reported by the instances (acts as a heartbeat).

        Returns:
            bool: True if update was successful
        """
        current_time = time.time()
        for item in vllm_instance_info:
            self.instance_infos[item.id] = item
            self.last_heartbeat[item.id] = current_time
        return True

    async def is_all_instances_ready(self):
        """Wait until all expected VLLM actors reach RUNNING status.

        Returns:
            No return value. Ends when all actors are ready (the health
            monitor is then started), or early on unexpected errors.
        """
        if not self._controller_handle:
            try:
                self._controller_handle = ray.get_actor(
                    name=CONTROLLER_ACTOR_NAME, namespace=DLLM_NAMESPACE)
            except BaseException:
                # Without the controller handle we cannot learn the expected
                # actor count; re-raise instead of crashing later with an
                # AttributeError on None.
                logger.error('get _controller_handle fail')
                raise
        expected_num = await self._controller_handle._get_expected_vllm_actors_num.remote(  # type: ignore # ray remote call
        )
        while self._get_ready_vllm_actors_num() < expected_num:
            try:
                logger.debug("%d/%d vllm actors ready, instances: %s",
                             self._get_ready_vllm_actors_num(), expected_num,
                             self.instance_infos)
                for s in self.instance_infos.values():
                    if s.status == VllmInstanceStatus.SUBPROCESS_EXITED:
                        raise RuntimeError(
                            f"vllm instance: {s} exited unexpectedly")
                await asyncio.sleep(1)
            except Exception as e:
                logger.error("An error when waiting vllm instances ready: %s",
                             e)
                return
        logger.info("All actors are ready")
        self.all_instances_ready = True
        asyncio.create_task(self._monitor_instance_health())

    def _get_ready_vllm_actors_num(self):
        """
        Get the number of ready VLLM instances.

        Returns:
            Number of ready VLLM instances.
        """
        return sum(info.status == VllmInstanceStatus.RUNNING
                   for info in self.instance_infos.values())

    def _get_unready_vllm_actors_num(self):
        """
        Get the number of unready VLLM instances.

        Returns:
            Number of unready VLLM instances.
        """
        return sum(info.status != VllmInstanceStatus.RUNNING
                   for info in self.instance_infos.values())

    async def _monitor_instance_health(self):
        """Monitor instance health; report failures to the controller.

        An instance is considered failed after a HEALTHCHECK_FAILED status or
        more than 20 s without a heartbeat. Monitoring pauses until all
        instances are ready again.
        """
        while True:
            if self.all_instances_ready:
                current_time = time.time()
                for info in self.instance_infos.values():
                    logger.info("Monitoring ID: %s, Status: %s", info.id,
                                info.status)
                    if info.status == VllmInstanceStatus.HEALTHCHECK_FAILED:
                        logger.error("Instance %s has failed health check.",
                                     info.id)
                        self._controller_handle.report_failure_from_balancer.remote(  # type: ignore # ray remote call
                            info.id)
                        self.all_instances_ready = False
                    # NOTE(review): a just-registered instance with no
                    # recorded heartbeat defaults to 0 and is reported
                    # immediately -- confirm that is intended.
                    elif current_time - self.last_heartbeat.get(info.id,
                                                                0) > 20:
                        logger.error(
                            "Instance %s is unhealthy (no heartbeat).",
                            info.id)
                        self._controller_handle.report_failure_from_balancer.remote(  # type: ignore # ray remote call
                            info.id)
                        self.all_instances_ready = False
            else:
                logger.info(
                    "Waiting for all instances ready and restart the health monitoring."
                )
                await asyncio.sleep(5)
            await asyncio.sleep(1)
142 changes: 142 additions & 0 deletions examples/online_serving/dllm_tool/dllm/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from dataclasses import dataclass
from typing import List, Optional

from dllm.entities import Role, SchedulerPolicy

# CLI options that dllm populates itself when launching vLLM instances.
# Each entry is a tuple of (long-option-name, optional-short-alias); user
# startup params must not contain any of them -- enforced by
# InferenceInstanceConfig._validate_startup_params.
required_vllm_options = [
    ("host", ),
    ("port", ),
    ("tensor-parallel-size", "tp"),
    ("data-parallel-size", "dp"),
    ("data-parallel-size-local", "dpl"),
    ("data-parallel-start-rank", "dpr"),
    ("data-parallel-address", "dpa"),
    ("data-parallel-rpc-port", "dpp"),
    ("headless", ),
    ("enable-expert-parallel", ),
    ("disable-expert-parallel", ),
    ("kv-transfer-config", ),
]


class AutoValidator:
    """Mixin for dataclasses: runs per-field validators after __init__.

    After dataclass initialization, for each declared field ``name`` the
    method ``_validate_<name>`` is invoked if the subclass defines one.
    Validators are expected to raise (e.g. ValueError) on invalid values.
    """

    def __post_init__(self):
        # __dataclass_fields__ maps field name -> Field on any dataclass
        # subclass; iterate the keys directly (the Field objects are unused,
        # so .items() was unnecessary here).
        for name in self.__dataclass_fields__:  # type: ignore[attr-defined]
            if validator := getattr(self, f"_validate_{name}", None):
                validator()


@dataclass
class InferenceInstanceConfig(AutoValidator):
    """Startup configuration for a single inference (vLLM) instance."""

    # Raw CLI arguments the user supplies for the vLLM process; must not
    # contain any option that dllm populates itself (see
    # required_vllm_options).
    startup_params: list[str]
    # Optional environment-variable string applied at launch.
    startup_env: Optional[str]
    tp: int  # tensor-parallel size; must be > 0
    dp: int  # data-parallel size; must be > 0
    ep: int  # expert-parallel size; 0 disables EP

    def _validate_startup_params(self):
        """Reject user-supplied options that dllm will populate itself.

        Both dash and underscore spellings are checked, in the forms
        ``--opt value`` and ``--opt=value`` (and ``-o``/``-o=value`` for
        short aliases).

        Raises:
            ValueError: if any reserved option appears in startup_params.
        """

        def __contain_long_options(opname, params):
            underline_op = opname.replace("-", "_")
            exact = (f"--{opname}", f"--{underline_op}")
            with_value = (f"--{opname}=", f"--{underline_op}=")
            return any(p in exact or p.startswith(with_value) for p in params)

        def __contain_short_options(opname, params):
            underline_op = opname.replace("-", "_")
            exact = (f"-{opname}", f"-{underline_op}")
            with_value = (f"-{opname}=", f"-{underline_op}=")
            return any(p in exact or p.startswith(with_value) for p in params)

        bad_options = []
        for option in required_vllm_options:
            # Every entry has at least a long name; short alias is optional.
            if __contain_long_options(option[0], self.startup_params):
                bad_options.append(option[0])
            if len(option) > 1 and __contain_short_options(
                    option[1], self.startup_params):
                bad_options.append(option[1])

        if bad_options:
            raise ValueError(
                f"{bad_options} should not be specified in start up commands, instead, dllm will populate options after verification"
            )

    def _validate_ep(self):
        # NOTE(review): the message says EP must be 0 or >1, but the check
        # only rejects negatives (ep == 1 passes) -- confirm intent.
        if self.ep < 0:  # type: ignore
            raise ValueError(
                "expert parallel size should be 0 (EP disabled) or >1 (EP enabled)"
            )

    def _validate_dp(self):
        if not self.dp > 0:  # type: ignore
            raise ValueError("data parallel size should be greater than 0")

    def _validate_tp(self):
        if not self.tp > 0:  # type: ignore
            raise ValueError("tensor parallel size should be greater than 0")


@dataclass
class ControllerConfig(AutoValidator):
    """Controller settings: scheduling policy plus per-role instance counts
    and configs."""

    scheduler_policy: SchedulerPolicy
    num_prefill_instances: int
    num_decode_instances: int
    prefill_instance_config: InferenceInstanceConfig
    decode_instance_config: InferenceInstanceConfig

    def _validate_num_prefill_instances(self):
        # Zero is allowed (e.g. mixed-only deployments); negatives are not.
        if self.num_prefill_instances >= 0:
            return
        raise ValueError(
            "number of prefill instances should be equal to or greater than 0"
        )

    def _validate_num_decode_instances(self):
        # Zero is allowed; negatives are not.
        if self.num_decode_instances >= 0:
            return
        raise ValueError(
            "number of decode instances should be equal to or greater than 0"
        )


@dataclass
class PDDistConfig(AutoValidator):
    """Prefill/decode disaggregation settings for one instance."""

    role: Role
    pd_rank: int = 0
    pd_size: int = 0

    def is_pd_dist(self):
        # Disaggregated mode is any role other than MIXED.
        return not (self.role == Role.MIXED)


@dataclass
class DPConfig(AutoValidator):
    """Data-parallel settings for a vLLM instance."""

    dp_rank: int = 0
    dp_size: int = 1
    dp_local_size: int = 1
    dp_master_ip: str = ""
    dp_master_port: int = 0

    def is_dp_enabled(self) -> bool:
        """True when more than one data-parallel rank is configured.

        Wrapped in bool(): the original ``x and x > 1`` could return the
        int 0 instead of False; truthiness is unchanged.
        """
        return bool(self.dp_size and self.dp_size > 1)


@dataclass
class EPConfig(AutoValidator):
    """Expert-parallel settings for a vLLM instance."""

    ep_size: int = 0  # 0 means EP disabled

    def is_ep_enabled(self) -> bool:
        """True when an expert-parallel size is configured.

        Wrapped in bool(): the original ``x and x > 0`` could return the
        int 0 instead of False; truthiness is unchanged.
        """
        return bool(self.ep_size and self.ep_size > 0)


@dataclass
class VllmInstanceConfig(AutoValidator):
    # Full command line used to launch the vLLM subprocess.
    exec_cmd: list[str]
    # Environment-variable string for the subprocess; presumably None means
    # inherit the parent environment -- TODO confirm against the launcher.
    env: Optional[str] = None
    # Tensor-parallel size for this instance.
    tp: int = 1
    # Data-parallel settings; None when DP is not configured.
    dp_config: Optional[DPConfig] = None
    # Expert-parallel settings; None when EP is not configured.
    ep_config: Optional[EPConfig] = None
    # Prefill/decode disaggregation settings; None when not in PD mode.
    pd_config: Optional[PDDistConfig] = None
Loading