
Commit 68b4755

[LLM] support multi node deploy (#2708)

* [LLM] support multi node deploy
* Update engine.py
* fix bugs
* fix
* [LLM] support multi node deploy
* [LLM] support multi node deploy

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
1 parent 04a8e1e commit 68b4755

File tree

13 files changed: +157 −87 lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -37,6 +37,7 @@ class CacheMessager(object):
     def __init__(self,
                  splitwise_role,
                  transfer_protocol,
+                 pod_ip,
                  engine_worker_queue_port,
                  local_data_parallel_id,
                  gpu_cache_kvs,
@@ -69,7 +70,7 @@ def __init__(self,
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = ('0.0.0.0', engine_worker_queue_port)
+        address = (pod_ip, engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
```
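The substance of this change: the queue client used to dial the wildcard address `'0.0.0.0'`, which only ever reaches the local host, so a messager on a second node could never find the queue server. A minimal, self-contained sketch of the bind-vs-dial distinction (addresses here are assumed, not from the commit):

```python
import socket

# Server side: a wildcard bind listens on every local interface.
srv = socket.socket()
srv.bind(("0.0.0.0", 9923))
srv.listen()

# Client side: '0.0.0.0' is not a routable destination, so a client must
# dial a concrete IP. On the same host loopback still works:
cli = socket.socket()
cli.connect(("127.0.0.1", 9923))
# From another node the client has to use the master's pod IP instead,
# e.g. cli.connect(("10.0.0.1", 9923)) -- which is why pod_ip is now
# threaded through to every queue client.
```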

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -71,6 +71,10 @@ def parse_args():
                         type=int,
                         default=9923,
                         help="cache queue port")
+    parser.add_argument("--pod_ip",
+                        type=str,
+                        default="0.0.0.0",
+                        help="pod ip")
     parser.add_argument("--engine_worker_queue_port",
                         type=int,
                         default=9923,
@@ -144,7 +148,7 @@ def __init__(self, args):
         self.rank = rank
         self.device = device

-        address = ('0.0.0.0', args.cache_queue_port)
+        address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
             address=address,
             is_server=False,
@@ -236,6 +240,7 @@ def __init__(self, args):
         self.cache_messager = CacheMessager(
             splitwise_role=args.splitwise_role,
             transfer_protocol=args.protocol,
+            pod_ip=args.pod_ip,
             engine_worker_queue_port=args.engine_worker_queue_port,
             local_data_parallel_id=args.local_data_parallel_id,
             gpu_cache_kvs=self.gpu_cache_kvs,
```
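For reference, a standalone sketch of how the new flag parses (argument values below are assumed): the `"0.0.0.0"` default keeps single-node runs behaving exactly as before this commit, while multi-node launches pass the master's IP explicitly.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument("--cache_queue_port", type=int, default=9923,
                    help="cache queue port")

# Single node: default wildcard address, unchanged behavior.
args = parser.parse_args([])
assert args.pod_ip == "0.0.0.0"

# Multi node: the launcher forwards the master's IP (assumed value).
args = parser.parse_args(["--pod_ip", "10.0.0.1"])
assert (args.pod_ip, args.cache_queue_port) == ("10.0.0.1", 9923)
```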

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -109,7 +109,7 @@ def __init__(self,


     def launch_cache_manager(self, cache_config, tensor_parallel_size, \
-                             device_ids, engine_worker_queue_port, pid_suffix):
+                             device_ids, pod_ip, engine_worker_queue_port, pid_suffix):
         """
         launch_cache_manager function used to initialize the cache manager.
         """
@@ -123,7 +123,7 @@ def launch_cache_manager(self, cache_config, tensor_parallel_size, \
                 create=True)

         self.cache_task_queue = EngineCacheQueue(
-            address=('127.0.0.1', cache_config.cache_queue_port),
+            address=(pod_ip, cache_config.cache_queue_port),
             authkey=b'cache_queue_service',
             is_server=False,
             num_client=tensor_parallel_size,
@@ -166,6 +166,7 @@ def launch_cache_manager(self, cache_config, tensor_parallel_size, \
             f" --cache_dtype {cache_config.cache_dtype}" +
             f" --cache_queue_port {cache_config.cache_queue_port}" +
             f" --enable_splitwise {int(self.enable_splitwise)}" +
+            f" --pod_ip {pod_ip}" +
             f" --engine_worker_queue_port {engine_worker_queue_port}" +
             f" --num_gpu_blocks {cache_config.total_block_num}" +
             f" --num_cpu_blocks {cache_config.num_cpu_blocks}" +
```

fastdeploy/engine/args_utils.py

Lines changed: 2 additions & 9 deletions
```diff
@@ -122,10 +122,7 @@ class EngineArgs:
     """
     Ratio of tokens to process in a block.
     """
-    nnode: int = 1
-    """
-    Number of nodes in the cluster.
-    """
+
     pod_ips: Optional[List[str]] = None
     """
     List of IP addresses for nodes in the cluster.
@@ -485,10 +482,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.pod_ips,
             help=
             "List of IP addresses for nodes in the cluster (comma-separated).")
-        system_group.add_argument("--nnode",
-                                  type=int,
-                                  default=EngineArgs.nnode,
-                                  help="Number of nodes in the cluster.")
+

         # Performance tuning parameters group
         perf_group = parser.add_argument_group("Performance Tuning")
@@ -773,7 +767,6 @@ def create_engine_config(self) -> Config:
             max_num_seqs=self.max_num_seqs,
             speculative_config=speculative_cfg,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            nnode=self.nnode,
             pod_ips=self.pod_ips,
             use_warmup=self.use_warmup,
             engine_worker_queue_port=self.engine_worker_queue_port,
```
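The removed `--nnode` flag is now redundant: `Config` derives the node count from the length of `pod_ips` (see fastdeploy/engine/config.py below). A short sketch of the new convention, with assumed addresses:

```python
# --pod_ips "192.168.0.1,192.168.0.2" arrives as a comma-separated string
# and is split by Config._str_to_list; nnode then falls out of its length.
pod_ips = "192.168.0.1,192.168.0.2".split(",")
nnode = len(pod_ips) if pod_ips else 1   # replaces the explicit --nnode flag
assert nnode == 2
```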

fastdeploy/engine/config.py

Lines changed: 36 additions & 19 deletions
```diff
@@ -505,7 +505,6 @@ def __init__(
         model_name_or_path: str = None,
         tokenizer: str = None,
         tensor_parallel_size: int = 8,
-        nnode: int = 1,
         max_model_len: int = 8192,
         max_num_seqs: int = 8,
         max_num_batched_tokens: Optional[int] = None,
@@ -539,7 +538,6 @@ def __init__(
             model_name_or_path (str): Model directory path or model name.
             tokenizer (str): Default is the model.
             tensor_parallel_size (int): Tensor parallel size. Default is 8.
-            nnode (int): Number of nodes. Default is 1.
             max_model_len (int): Maximum model length. Default is 8192.
             max_num_seqs (int): Maximum number of sequences. Default is 8.
             max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
@@ -565,7 +563,6 @@ def __init__(
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.nnode = nnode
         self.pod_ips = pod_ips
         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
@@ -585,12 +582,15 @@ def __init__(
         self.max_capture_batch_size = max_capture_batch_size
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
+        self.is_master = True
+        self._str_to_list("innode_prefill_ports", int)
+        self._str_to_list("pod_ips", str)

-        if self.innode_prefill_ports is not None:
-            if not isinstance(self.innode_prefill_ports, list):
-                ports = str(self.innode_prefill_ports).split(',')
-                self.innode_prefill_ports = [int(port) for port in ports]
-
+        if self.pod_ips is None:
+            self.nnode = 1
+        else:
+            self.nnode = len(self.pod_ips)
+
         assert self.splitwise_role in ["mixed", "prefill", "decode"]

         # TODO
@@ -609,14 +609,15 @@ def __init__(

         num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
         if num_ranks > 8:
-            local_num_ranks = 8
-            self.nnode = ceil_div(num_ranks, local_num_ranks)
+            self.worker_num_per_node = 8
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, \
+                f"nnode: {nnode}, but got {self.nnode}"
         else:
-            local_num_ranks = num_ranks
+            self.worker_num_per_node = num_ranks

         self.engine_worker_queue_port = engine_worker_queue_port
-        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
-            self.parallel_config.expert_parallel_size), 8))])
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

         self.read_from_config()
@@ -628,16 +629,21 @@ def postprocess(self):
         """
         calculate some parameters
         """
-        total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
-        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
-            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
+        assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
+            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
+
+        assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
+            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
         self.local_device_ids = self.device_ids.split(
             ',')[:self.tensor_parallel_size]
-        assert self.tensor_parallel_size % self.nnode == 0, \
-            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
-        self.worker_num_per_node = total_rank // self.nnode
+
         self.host_ip = get_host_ip()

+        if self.pod_ips is None:
+            self.pod_ips = ["0.0.0.0"]
+        elif self.host_ip != self.pod_ips[0]:
+            self.is_master = False
+
         import paddle
         self.paddle_commit_id = paddle.version.commit

@@ -808,5 +814,16 @@ def reset_value(cls, value_name, key):
             "return_full_hidden_states")
         reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")

+    def _check_master(self):
+        return self.is_master
+
+    def _str_to_list(self, attr_name, default_type):
+        if hasattr(self, attr_name):
+            val = getattr(self, attr_name)
+            if type(val) is str:
+                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
+            else:
+                setattr(self, attr_name, val)
+
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
```
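Taken together, the `Config` changes replace the user-supplied `nnode` with values derived from `pod_ips`, plus a master check against the local host IP. A condensed, runnable sketch of that logic (the standalone function and its name are mine for illustration, not FastDeploy API):

```python
def ceil_div(a, b):
    return -(-a // b)

def resolve_cluster(pod_ips, host_ip, tensor_parallel_size, expert_parallel_size=1):
    # Node count falls out of pod_ips instead of an explicit --nnode flag.
    nnode = 1 if pod_ips is None else len(pod_ips)
    num_ranks = tensor_parallel_size * expert_parallel_size
    worker_num_per_node = 8 if num_ranks > 8 else num_ranks
    if num_ranks > 8:
        # Multi-node: the supplied pod_ips must cover every rank.
        assert ceil_div(num_ranks, worker_num_per_node) == nnode
    pod_ips = pod_ips or ["0.0.0.0"]
    # pod_ips[0] is the master; any other host flips is_master off.
    is_master = pod_ips[0] == "0.0.0.0" or host_ip == pod_ips[0]
    return nnode, worker_num_per_node, pod_ips, is_master

# Two 8-GPU nodes serving TP=16 (addresses assumed):
print(resolve_cluster(["10.0.0.1", "10.0.0.2"], host_ip="10.0.0.2",
                      tensor_parallel_size=16))
# -> (2, 8, ['10.0.0.1', '10.0.0.2'], False)
```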

fastdeploy/engine/engine.py

Lines changed: 49 additions & 37 deletions
```diff
@@ -98,30 +98,7 @@ def __init__(self, cfg):
             cfg.mm_processor_kwargs,
             cfg.enable_mm)

-        address = ('0.0.0.0', self.cfg.engine_worker_queue_port)
-        self.engine_worker_queue_server = EngineWorkerQueue(
-            address=address,
-            is_server=True,
-            num_client=self.cfg.tensor_parallel_size,
-            local_data_parallel_size=self.cfg.parallel_config.
-            data_parallel_size)
-
-        self.engine_worker_queue = EngineWorkerQueue(
-            address=address,
-            is_server=False,
-            num_client=self.cfg.tensor_parallel_size,
-            client_id=0,
-            local_data_parallel_id=0)
-
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
-            self.cache_task_queue = EngineCacheQueue(
-                address=('127.0.0.1', self.cfg.cache_config.cache_queue_port),
-                authkey=b'cache_queue_service',
-                is_server=True,
-                num_client=self.cfg.tensor_parallel_size,
-                client_id=-1,
-                local_data_parallel_size=self.cfg.parallel_config.
-                data_parallel_size)
+        self.start_queue_service()

         self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg,
                                                 cfg.tensor_parallel_size,
@@ -198,9 +175,12 @@ def start(self, api_server_pid=None):
                 or self.cfg.splitwise_role != "mixed"):
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
-                self.cfg.cache_config, self.cfg.tensor_parallel_size,
-                device_ids, self.cfg.engine_worker_queue_port,
-                self.ipc_signal_suffix)
+                cache_config=self.cfg.cache_config,
+                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                device_ids=device_ids,
+                pod_ip=self.cfg.pod_ips[0],
+                engine_worker_queue_port=self.cfg.engine_worker_queue_port,
+                pid_suffix=self.ipc_signal_suffix)

         self.worker_proc = self._start_worker_service()
         console_logger.info("Waitting worker processes ready...")
@@ -850,10 +830,7 @@ def _init_worker_signals(self):
         Initialize shared memory to indicate engine status
         """
         # worker_ready_signal
-        array_size = min(
-            8, self.cfg.tensor_parallel_size *
-            self.cfg.parallel_config.data_parallel_size)
-        worker_ready_signal_data = np.zeros(shape=[array_size], dtype=np.int32)
+        worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_ready_signal = IPCSignal(name="worker_ready_signal",
                                              array=worker_ready_signal_data,
                                              dtype=np.int32,
@@ -889,7 +866,7 @@ def _init_worker_signals(self):
                                     create=True)

         # worker_live_signal lets the engine sense whether each worker process is alive; records the time of every step
-        worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size],
+        worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node],
                                                            dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
@@ -899,7 +876,7 @@ def _init_worker_signals(self):
                                     create=True)

         if self.do_profile:
-            get_profile_block_num = np.zeros([array_size], dtype=np.int32)
+            get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(
                 name="get_profile_block_num",
                 array=get_profile_block_num,
@@ -1028,13 +1005,15 @@ def _start_worker_service(self):

         arguments = (
             f" --nnodes {str(self.cfg.nnode)}"
+            f" --ips {','.join(self.cfg.pod_ips)}"
             f" --devices {self.cfg.device_ids} {py_script}"
             f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
             f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
             f" --model_name_or_path {str(self.cfg.model_name_or_path)}"
             f" --device_ids {self.cfg.device_ids}"
             f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
             f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}"
+            f" --pod_ip {self.cfg.pod_ips[0]}"
             f" --total_block_num {self.cfg.cache_config.total_block_num}"
             f" --block_size {self.cfg.cache_config.block_size}"
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
@@ -1171,10 +1150,12 @@ def _stop_profile(self):
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
-                self.cfg.cache_config, self.cfg.tensor_parallel_size,
-                device_ids, self.cfg.engine_worker_queue_port,
-                self.ipc_signal_suffix)
-
+                cache_config=self.cfg.cache_config,
+                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                device_ids=device_ids,
+                pod_ip=self.cfg.pod_ips[0],
+                engine_worker_queue_port=self.cfg.engine_worker_queue_port,
+                pid_suffix=self.ipc_signal_suffix)
     def check_health(self, time_interval_threashold=30):
         """
         Check the health of the model server by checking whether all workers are alive.
@@ -1254,3 +1235,34 @@ def detect_thread():
         except Exception:
             pass
         return True
+
+    def start_queue_service(self):
+        """
+        start queue service for engine worker communication
+        """
+        address = (self.cfg.pod_ips[0], self.cfg.engine_worker_queue_port)
+        if self.cfg.host_ip == self.cfg.pod_ips[0] or self.cfg.pod_ips[0] == "0.0.0.0":
+            self.engine_worker_queue_server = EngineWorkerQueue(
+                address=address,
+                is_server=True,
+                num_client=self.cfg.tensor_parallel_size,
+                local_data_parallel_size=self.cfg.parallel_config.
+                data_parallel_size)
+
+            if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
+                self.cache_task_queue = EngineCacheQueue(
+                    address=(self.cfg.pod_ips[0], self.cfg.cache_config.cache_queue_port),
+                    authkey=b'cache_queue_service',
+                    is_server=True,
+                    num_client=self.cfg.tensor_parallel_size,
+                    client_id=-1,
+                    local_data_parallel_size=self.cfg.parallel_config.
+                    data_parallel_size)
+
+
+        self.engine_worker_queue = EngineWorkerQueue(
+            address=address,
+            is_server=False,
+            num_client=self.cfg.tensor_parallel_size,
+            client_id=0,
+            local_data_parallel_id=0)
```
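The new `start_queue_service` encodes the multi-node topology: only the node whose host IP matches `pod_ips[0]` (or a single-node `"0.0.0.0"` setup) hosts the server side of the queues, while every node, master included, attaches a client to the same address. A minimal standard-library sketch of that pattern, assuming `EngineWorkerQueue` wraps similar `multiprocessing` machinery (all names here are illustrative, not the FastDeploy API):

```python
from multiprocessing.managers import BaseManager
import queue

_tasks = queue.Queue()

def _get_tasks():
    return _tasks

class WorkerQueue(BaseManager):
    """Stand-in for EngineWorkerQueue (illustrative only)."""

# The callable only runs on the serving side; clients just get a proxy.
WorkerQueue.register("tasks", callable=_get_tasks)

def bring_up(pod_ips, host_ip, port):
    address = (pod_ips[0], port)
    server = None
    if host_ip == pod_ips[0] or pod_ips[0] == "0.0.0.0":
        server = WorkerQueue(address=address, authkey=b"worker_queue")
        server.start()                      # master only: mirrors is_server=True
    client = WorkerQueue(address=address, authkey=b"worker_queue")
    client.connect()                        # every node: mirrors is_server=False
    return server, client.tasks()

if __name__ == "__main__":
    # Single-host demo; on a cluster pod_ips[0] would be the master's IP.
    server, tasks = bring_up(["127.0.0.1"], host_ip="127.0.0.1", port=9923)
    tasks.put("ready")
    assert tasks.get() == "ready"
```

Note also that `_start_worker_service` now forwards `--ips` to the distributed launcher and `--pod_ip` to each worker, so worker processes on non-master nodes connect to the master's queues rather than their own loopback.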

fastdeploy/engine/expert_service.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -65,7 +65,7 @@ def __init__(self, cfg, local_data_parallel_id):

         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id

-        address = ('0.0.0.0', cfg.engine_worker_queue_port)
+        address = (cfg.pod_ips[0], cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
```
