Skip to content

Automatically detects RDMA devices, eliminating complex manual setup for mooncake #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class KVPoll:
class KVManager:
# TODO: make it general and support multiple transfer backend before merging
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine()
self.engine = MooncakeTransferEngine(args.gpu_id)
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
self.request_pool: RequestPoolType = {}
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
def _init_kv_manager(self) -> KVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_args.gpu_id = self.scheduler.gpu_id
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
tp_size: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
scheduler: Scheduler,
):
self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype
Expand All @@ -68,6 +69,7 @@ def __init__(
self.queue: List[Req] = []
self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port
self.scheduler = scheduler

def allocate_token_id(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative"
Expand All @@ -84,6 +86,7 @@ def _init_kv_manager(self) -> KVManager:
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
kv_args.gpu_id = self.scheduler.gpu_id

# Define req -> input ids buffer
kv_args.aux_data_ptrs = [
Expand Down
152 changes: 152 additions & 0 deletions python/sglang/srt/disaggregation/rdma_device_utils.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We think it's better to put topology detection inside the Mooncake Transfer Engine. You can check out our PR here and try it out. To enable this, just leave the device_name items blank.

Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# coding:utf-8
"""
@author: nivic ybyang7
@license: Apache Licence
@file: ib_devices
@time: 2025/04/03
@contact: ybyang7@iflytek.com
@site:
@software: PyCharm

# Code is far away from bugs with the god animal protecting
I love animals. They taste delicious.
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ☃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ God Bless ┣┓
┃ No BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""
import os

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
import pyverbs.device as d
import pynvml


def get_device_list(prefix, gpu_no=0, roce_version=2, port_num=1):
    """
    Collect the RDMA devices whose name starts with ``prefix``.

    Args:
        prefix (str): Device name prefix to filter (e.g., 'mlx')
        gpu_no (int): Unused by this function; kept so existing callers
            that pass it keep working (default: 0)
        roce_version (int): Value passed as the GID table ``index`` to
            ``query_gid`` (default: 2)
        port_num (int): Port number to query (default: 1)

    Returns:
        dict: Mapping of RDMA device name to the last path component of its
        sysfs ``device`` symlink (e.g. "0000:19:00.0"). Empty when no device
        is present or none matches.
    """
    devices = d.get_device_list()
    if not devices:
        print("No IB devices")
        # Always hand back a dict so callers can use .items() unconditionally.
        return {}

    device_list = {}
    for dev in devices:
        name = dev.name.decode()
        if not name.startswith(prefix):
            continue
        with d.Context(name=name) as ctx:
            if ctx.query_port(port_num).gid_tbl_len <= 0:
                continue
            # The result is discarded; presumably this is a probe that raises
            # when the requested GID entry is unavailable -- TODO confirm.
            ctx.query_gid(port_num=port_num, index=roce_version)
            # Get PCI address from sysfs
            dev_path = f"/sys/class/infiniband/{name}/device"
            if os.path.exists(dev_path):
                # Symlink target ends with the PCI address, e.g. "0000:19:00.0"
                device_list[name] = os.readlink(dev_path).split("/")[-1]

    return device_list


def get_gpu_pci_address(gpu_no):
    """
    Get the PCI address of a specified GPU device.

    Args:
        gpu_no (int): GPU device number (NVML device index)

    Returns:
        str: PCI bus id of the GPU device as reported by NVML
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_no)
        bus_id = pynvml.nvmlDeviceGetPciInfo(handle).busId
    finally:
        # Release NVML even when the index is invalid or the query fails.
        pynvml.nvmlShutdown()
    # Some pynvml releases expose C string fields as bytes; callers split the
    # value as text, so normalise it to str here.
    if isinstance(bus_id, bytes):
        bus_id = bus_id.decode()
    return bus_id


def get_net_device_from_rdma(rdma_dev):
    """
    Get the network interface name corresponding to an RDMA device.

    Args:
        rdma_dev (str): RDMA device name

    Returns:
        str: First entry of the device's sysfs ``net`` directory, or None if
        the directory is missing or empty
    """
    net_path = f"/sys/class/infiniband/{rdma_dev}/device/net"
    try:
        interfaces = os.listdir(net_path)
    except (FileNotFoundError, NotADirectoryError):
        # No sysfs entry: the device has no associated network interface.
        return None
    # An empty directory previously raised IndexError; treat it as "not found".
    return interfaces[0] if interfaces else None


def normalize_pci_addr(pci_addr):
    """
    Bring a PCI address into the "dddd:bb:dd.f" layout.

    Args:
        pci_addr (str): PCI address to normalize

    Returns:
        str: For a three-field address such as "00000000:08:00.0", the same
        address with the domain re-rendered as at least four lowercase hex
        digits ("0000:08:00.0"); any other input is returned unchanged.
    """
    fields = pci_addr.split(":")
    if len(fields) != 3:
        # Not "domain:bus:device.function" -- leave it as it is.
        return pci_addr
    domain, bus, dev_fn = fields
    return "{:04x}:{}:{}".format(int(domain, 16), bus, dev_fn)


def find_best_rdma_device_for_gpu(gpu_no, prefix="mlx"):
    """
    Find the RDMA device whose PCI bus number is closest to a given GPU's.

    Only devices sharing the GPU's PCI domain (the first five characters of
    the normalized address, e.g. "0000:") are considered; among those, the
    one with the smallest absolute bus-number difference wins, the first
    such device on ties.

    Args:
        gpu_no (int): GPU device number
        prefix (str): RDMA device name prefix (default: "mlx")

    Returns:
        tuple: (best_rdma_dev, net_dev) containing the best RDMA device and
        its network interface, or (None, None) when no device qualifies, so
        callers can always unpack the result
    """
    gpu_pci = normalize_pci_addr(get_gpu_pci_address(gpu_no))
    gpu_bus = int(gpu_pci.split(":")[1], 16)

    # `or {}` tolerates a device lookup that reports "nothing found" with a
    # falsy non-dict value.
    found = get_device_list(prefix) or {}
    roce_devices = {name: normalize_pci_addr(addr) for name, addr in found.items()}

    # Same PCI domain only (this is a domain check, not a NUMA-node check).
    candidates = [
        (abs(int(rdma_pci.split(":")[1], 16) - gpu_bus), rdma_dev)
        for rdma_dev, rdma_pci in roce_devices.items()
        if rdma_pci[:5] == gpu_pci[:5]
    ]
    if not candidates:
        return None, None

    # min() keeps the first entry on ties, matching a strict "<" scan.
    _, best_rdma_dev = min(candidates, key=lambda item: item[0])
    return best_rdma_dev, get_net_device_from_rdma(best_rdma_dev)


if __name__ == '__main__':
    # Manual smoke test: report the RDMA device picked for GPU 0.
    gpu_no = 0  # GPU device number to query
    # The function defined in this module is find_best_rdma_device_for_gpu;
    # `or (None, None)` keeps the unpacking safe if no device was found.
    rdma_dev, net_dev = find_best_rdma_device_for_gpu(gpu_no) or (None, None)
    print(f"GPU {gpu_no} most affinity RDMA device: {rdma_dev}, corresponding network interface: {net_dev}")
34 changes: 26 additions & 8 deletions python/sglang/srt/disaggregation/transfer_engine/mooncake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import uuid
from dataclasses import dataclass

from sglang.srt.utils import get_local_ip_by_remote
from sglang.srt.disaggregation.rdma_device_utils import find_best_rdma_device_for_gpu

logger = logging.getLogger(__name__)


Expand All @@ -27,19 +30,36 @@ def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
)

@staticmethod
def load_config(gpu_id=None) -> "MooncakeTransferEngineConfig":
    """Load the Mooncake config from a file, or build it from the environment.

    When the environment variable 'MOONCAKE_CONFIG_PATH' is set, the config
    is read from that file. Otherwise the config is assembled by
    ``load_auto_config`` for the given ``gpu_id``.

    Args:
        gpu_id: GPU index forwarded to ``load_auto_config``.

    Returns:
        MooncakeTransferEngineConfig: The loaded configuration.
    """
    config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
    if config_file_path is None:
        logger.info("No config set for 'MOONCAKE_CONFIG_PATH', specified env is preferred")
        # The fallback builder on this class is `load_auto_config`.
        return MooncakeTransferEngineConfig.load_auto_config(gpu_id)
    return MooncakeTransferEngineConfig.from_file(config_file_path)

@staticmethod
def load_auto_config(gpu_id) -> "MooncakeTransferEngineConfig":
    """Build the config from environment variables, auto-detecting the rest.

    Environment variables read:
        MOONCAKE_METADATA_SERVER: required.
        MOONCAKE_LOCAL_HOSTNAME: optional; falls back to
            ``get_local_ip_by_remote()``.
        MOONCAKE_PROTOCOL: optional; defaults to "rdma".
        MOONCAKE_RDMA_DEVICE_NAME: optional; falls back to the device chosen
            by ``find_best_rdma_device_for_gpu(gpu_id)``.

    Args:
        gpu_id: GPU index used for RDMA device detection.

    Returns:
        MooncakeTransferEngineConfig: The assembled configuration.

    Raises:
        ValueError: If 'MOONCAKE_METADATA_SERVER' is not set, or if no RDMA
            device name is configured and none could be detected.
    """
    metadata_server = os.getenv("MOONCAKE_METADATA_SERVER", None)
    if metadata_server is None:
        raise ValueError(
            "The environment variable 'MOONCAKE_METADATA_SERVER' is not set."
        )

    # Resolve fallbacks lazily: the socket probe and the device scan only run
    # when the corresponding override is absent.
    local_hostname = os.getenv("MOONCAKE_LOCAL_HOSTNAME")
    if local_hostname is None:
        local_hostname = get_local_ip_by_remote()

    protocol = os.getenv("MOONCAKE_PROTOCOL", default="rdma")

    device_name = os.getenv("MOONCAKE_RDMA_DEVICE_NAME")
    if device_name is None:
        detected = find_best_rdma_device_for_gpu(gpu_id)
        # Detection may yield nothing at all or a (None, None) pair.
        device_name = detected[0] if detected else None
        if device_name is None:
            raise ValueError(
                f"No RDMA device could be detected for GPU {gpu_id}; "
                "set the environment variable 'MOONCAKE_RDMA_DEVICE_NAME'."
            )

    return MooncakeTransferEngineConfig(
        local_hostname=local_hostname,
        metadata_server=metadata_server,
        protocol=protocol,
        device_name=device_name,
    )

class MooncakeTransferEngine:

def __init__(self):
def __init__(self, gpu_id=0):
try:
from mooncake.engine import TransferEngine
except ImportError as e:
Expand All @@ -52,7 +72,7 @@ def __init__(self):
self.engine = TransferEngine()

try:
self.config = MooncakeTransferEngineConfig.load_from_env()
self.config = MooncakeTransferEngineConfig.load_auto_config(gpu_id)
logger.info("Mooncake Configuration loaded successfully.")
except ValueError as e:
logger.error(e)
Expand All @@ -61,8 +81,6 @@ def __init__(self):
logger.error("An error occurred while loading the configuration: %s", exc)
raise

self.config = MooncakeTransferEngineConfig.load_from_env()

session_suffix = "_" + str(uuid.uuid4())
self.session_id = self.config.local_hostname + session_suffix
self.initialize(
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def init_disaggregation(self):
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
scheduler=self,
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_infight_queue: List[Req] = []
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,3 +1828,13 @@ def fast_topk(values, topk, dim):
else:
# Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim)

def get_local_ip_by_remote(addr="8.8.8.8:8888"):
    """
    Return the local IP address the OS selects for reaching ``addr``.

    ``addr`` is a "host:port" string. A UDP socket is connected to it and
    the socket's own address is read back.
    """
    remote_host, remote_port = addr.split(":")
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    with probe:
        # connecting fake server to get ip host
        target = (remote_host, int(remote_port.strip()))
        probe.connect(target)
        local_ip = probe.getsockname()[0]
    return local_ip
Loading