Skip to content

WIP: Support multi prefill instances on one node #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: mooncake_transfer_engine
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions python/sglang/srt/disaggregation/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class KVArgs:
engine_rank: int
tp_size: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
Expand Down Expand Up @@ -149,6 +150,7 @@ def send_aux(
def sync_status_to_decode_endpoint(self, remote: str, room: int):
if ":" in remote:
remote = remote.split(":")[0]
# TODO(yuan-luo): change the receiver rank port to configurable
self._connect(
"tcp://"
+ remote
Expand Down Expand Up @@ -299,10 +301,26 @@ def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
self.bootstrap_room = bootstrap_room
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr

@cache
def _connect_router(self, endpoint: str):
socket = zmq.Context().socket(zmq.DEALER)
self.identity = str(uuid.uuid4()).encode()
socket.setsockopt(zmq.IDENTITY, self.identity)
socket.connect(endpoint)
return socket

def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.aux_index = aux_index
self.num_kv_indices = num_kv_indices
self._connect_router("tcp://" + self.bootstrap_server_url).send_multipart(
[
"Prefill".encode("ascii"),
str(self.kv_mgr.engine_rank).encode("ascii"),
str(self.kv_mgr.tp_size).encode("ascii"),
]
)

def send(self, kv_indices: npt.NDArray[np.int32]):
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, self.aux_index)
Expand All @@ -317,19 +335,26 @@ def failure_exception(self):
class KVReceiver:

def __init__(
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
self, mgr: KVManager, bootstrap_addr: str, prefill_addr: str, bootstrap_room: Optional[int] = None
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.prefill_addr = prefill_addr
self.kv_mgr = mgr
self.prefill_server_url = (
bootstrap_addr.split(":")[0]
+ ":"
+ str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
)
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
self.bootstrap_addr = bootstrap_addr
self.prefill_engine_rank = None
self.prefill_tp_size = None

@cache
def _connect_router(self, endpoint: str):
socket = zmq.Context().socket(zmq.DEALER)
self.identity = str(uuid.uuid4()).encode()
socket.setsockopt(zmq.IDENTITY, self.identity)
socket.connect(endpoint)
return socket

@cache
def _connect(self, endpoint: str):
Expand All @@ -338,13 +363,41 @@ def _connect(self, endpoint: str):
return socket

def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
self._connect("tcp://" + self.bootstrap_addr).send_multipart(
[
"Decode".encode("ascii"),
str(0).encode("ascii"),
str(0).encode("ascii"),
]
)
# Start listen clients thread
self.zmq_thread = threading.Thread(target=self._listen_server, daemon=True) # ZMQ communication thread
self.zmq_thread.start()

def _listen_server(self):
while True:
# Receive messages from bootstrap server (Prefill)
(role, engine_rank, tp_size) = self.router_socket.recv_multipart()
role = role.decode("ascii")
if role == "Decode":
self.prefill_engine_rank = int(engine_rank.decode("ascii"))
self.prefill_tp_size = int(tp_size.decode("ascii"))
self.handshake_prefill_server(kv_indices, aux_index)

def handshake_prefill_server(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, aux_index)
packed_kv_data_ptrs = b"".join(
struct.pack("q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sender_polling_port = int(self.prefill_addr.split(":")[1]) - 1
self.prefill_server_url = (
self.bootstrap_addr.split(":")[0]
+ ":"
+ str(sender_polling_port + self.prefill_engine_rank)
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
self.decode_ip.encode("ascii"),
Expand All @@ -366,19 +419,56 @@ def failure_exception(self):

class KVBootstrapServer:
def __init__(self, port: int):
self.port = port
self.route_port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()

self.context = zmq.Context()

# ROUTER socket to communicate with Prefill and Decode
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.bind(f"tcp://*:{self.route_port}")

self.prefill_engine_rank = None
self.prefill_tp_size = None

# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()

def run(self):
self.thread.start()

def _listen_clients(self):
while True:
# Receive messages from clients (Prefill or Decode)
(role, engine_rank, tp_size) = self.router_socket.recv_multipart()
role = role.decode("ascii")
engine_rank = int(engine_rank.decode("ascii"))
tp_size = int(tp_size.decode("ascii"))
if role == "Prefill":
self._handle_prefill(engine_rank, tp_size)
elif role == "Decode":
self._handle_decode()

def _handle_prefill(self, engine_rank, tp_size):
"""Handle Prefill message"""
self.prefill_engine_rank = engine_rank
self.prefill_tp_size = tp_size

def _handle_decode(self):
"""Handle Decode message"""
if self.prefill_engine_rank is None or self.prefill_tp_size is None:
print("Error: Metadata not yet received from Prefill.")
self.router_socket.send_multipart([identity, b"Error: Metadata not ready"])
else:
self.router_socket.send_multipart(
[
"Decode".encode("ascii"),
str(prefill_engine_rank).encode("ascii"),
str(prefill_tp_size).encode("ascii"),
]
)

def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)

Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
def _init_kv_manager(self) -> KVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_args.tp_size = self.tp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
Expand All @@ -124,8 +125,9 @@ def add(self, req: Req) -> None:

kv_receiver = KVReceiver(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
prefill_addr=f"{req.prefill_host}:{req.prefill_port}"
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/disaggregation/mini_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ async def generate_request(self, request_data):
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = f"{hostname}"
bootstrap_port = int(parsed_url.port) + 1
prefill_host = f"{hostname}"
prefill_port = int(parsed_url.port)

modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
"prefill_host": prefill_host,
"prefill_port": prefill_port,
}
)

Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
aux_dtype: torch.dtype,
tp_rank: int,
tp_size: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
):
self.token_to_kv_pool = token_to_kv_pool
Expand All @@ -67,7 +66,6 @@ def __init__(
self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = []
self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port

def allocate_token_id(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative"
Expand All @@ -77,6 +75,7 @@ def allocate_token_id(self, idx: int, token_id: int):
def _init_kv_manager(self) -> KVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_args.tp_size = self.tp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
Expand All @@ -102,8 +101,9 @@ def _init_kv_manager(self) -> KVManager:
def add(self, req: Req) -> None:
req.disagg_kv_sender = KVSender(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
prefill_addr=f"{req.prefill_host}:{req.prefill_port}"
)
self._process_req(req)
self.queue.append(req)
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ class GenerateReqInput:

# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
prefill_host: Optional[str] = None
prefill_port: Optional[int] = None

def normalize_batch_and_arguments(self):
if (
Expand Down Expand Up @@ -306,7 +309,10 @@ class TokenizedGenerateReqInput:

# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
prefill_host: Optional[str] = None
prefill_port: Optional[int] = None


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,10 @@ def __init__(
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
prefill_host: Optional[str] = None,
prefill_port: Optional[int] = None,
):
# Input and output info
self.rid = rid
Expand Down Expand Up @@ -487,7 +490,10 @@ def __init__(

# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_port: int = bootstrap_port
self.bootstrap_room: Optional[int] = bootstrap_room
self.prefill_host: str = prefill_host
self.prefill_port: int = prefill_port
self.disagg_kv_sender: Optional[KVSender] = None

# used for warmup because we don't have a pair yet when init
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,10 @@ def handle_generate_request(
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
prefill_host=recv_req.prefill_host,
prefill_port=recv_req.prefill_port,
)
req.tokenizer = self.tokenizer

Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,10 @@ async def _tokenize_one_request(
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
prefill_host=obj.prefill_host,
prefill_port=obj.prefill_port,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
Expand Down
7 changes: 0 additions & 7 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ class ServerArgs:

# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998

def __post_init__(self):
# Expert parallelism
Expand Down Expand Up @@ -1157,12 +1156,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down
Loading