Skip to content

Commit ac9fb73

Browse files
authored
On environments where numa cannot be detected we get 0 (#21115)
Signed-off-by: Eric Curtin <ecurtin@redhat.com>
1 parent a3a6c69 commit ac9fb73

File tree

1 file changed

+111
-77
lines changed

1 file changed

+111
-77
lines changed

vllm/v1/worker/cpu_worker.py

Lines changed: 111 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@
1313
from vllm.model_executor.utils import set_random_seed
1414
from vllm.platforms import CpuArchEnum, current_platform
1515
from vllm.sequence import IntermediateTensors
16+
from vllm.utils import PlaceholderModule
1617
from vllm.v1.core.sched.output import SchedulerOutput
1718
from vllm.v1.outputs import ModelRunnerOutput
1819
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
1920
from vllm.v1.worker.gpu_worker import (Worker,
2021
init_worker_distributed_environment)
2122

23+
# psutil and numa are optional dependencies. Each one is imported in its own
# try block so that a missing package only replaces its own name with a
# placeholder and does not prevent the other import from being attempted.
try:
    import psutil
except ImportError:
    psutil = PlaceholderModule("psutil")  # type: ignore[assignment]

try:
    from numa import info
except ImportError:
    numa = PlaceholderModule("numa")  # type: ignore[assignment]
    # The methods in this module reference `info`, not `numa`; bind it too so
    # the name is always defined at module level instead of raising NameError.
    info = numa.placeholder_attr("info")  # type: ignore[assignment]
29+
2230
logger = init_logger(__name__)
2331

2432

@@ -37,6 +45,8 @@ def __init__(self,
3745
is_driver_worker=is_driver_worker)
3846

3947
self.parallel_config.disable_custom_all_reduce = True
48+
self.manually_bind_threads_suggestion = (
49+
"To get better performance, please try to manually bind threads.")
4050

4151
def init_device(self):
4252
# Setup OpenMP threads affinity.
@@ -112,50 +122,111 @@ def execute_model(
112122
assert isinstance(output, ModelRunnerOutput)
113123
return output if self.is_driver_worker else None
114124

125+
def warn_inability_to_detect_numa(self) -> None:
    """Warn that auto thread-binding failed because no NUMA node was found.

    The manual-binding suggestion stored on the worker is appended to the
    log message.
    """
    suggestion = self.manually_bind_threads_suggestion
    message = ("Auto thread-binding failed due to the "
               "inability to detect numa nodes. %s")
    logger.warning(message, suggestion)
130+
131+
def warn_lack_of_numa_and_psutil(self) -> None:
    """Warn that auto thread-binding failed because numa/psutil are missing.

    The manual-binding suggestion stored on the worker is appended to the
    log message.
    """
    suggestion = self.manually_bind_threads_suggestion
    message = ("Auto thread-binding failed due to "
               "the lack of package numa and psutil. %s")
    logger.warning(message, suggestion)
136+
137+
def warn_world_size_too_large(self, world_size: int,
                              node_to_cpus_len: int) -> None:
    """Warn that there are fewer usable NUMA nodes than ranks.

    Args:
        world_size: Number of ranks that each need a NUMA node.
        node_to_cpus_len: Number of NUMA nodes that have at least one
            allowed CPU.
    """
    message = ("Auto thread-binding failed due to "
               "world size: %d being larger than "
               "allowed NUMA nodes number: %d. %s")
    log_args = (world_size, node_to_cpus_len,
                self.manually_bind_threads_suggestion)
    logger.warning(message, *log_args)
144+
145+
def get_cpus_allow_list_and_numa_size(self):
    """Return the process CPU affinity and the configured NUMA node count.

    Returns:
        A ``(cpus_allow_list, numa_size)`` pair: the result of
        ``psutil.Process().cpu_affinity()`` for the current process and the
        result of ``info.get_num_configured_nodes()``. Callers in this file
        treat a falsy ``numa_size`` as "NUMA could not be detected".
    """
    # The affinity is queried first, then the node count, as before.
    return (psutil.Process().cpu_affinity(),
            info.get_num_configured_nodes())
149+
150+
def auto_thread_binding_based_on_numa_nodes(self, world_size: int,
                                            rank_to_cpus: str) -> str:
    """Compute the CPU id list this rank should bind its threads to.

    One NUMA node (restricted to the CPUs this process is allowed to run
    on) is assigned per rank. Up to ``envs.VLLM_CPU_NUM_OF_RESERVED_CPU``
    CPUs, capped at half of a node's share of physical cores, are left out
    of the binding.

    Args:
        world_size: Number of ranks that each need their own NUMA node.
        rank_to_cpus: Binding to fall back to when auto-binding is not
            possible.

    Returns:
        A comma-separated string of CPU ids for ``self.rank``, or
        ``rank_to_cpus`` unchanged when NUMA nodes or the physical core
        count cannot be detected, or when there are fewer usable NUMA
        nodes than ranks (a warning is logged in each of these cases).
    """
    cpus_allow_list, numa_size = self.get_cpus_allow_list_and_numa_size()
    if not numa_size:
        self.warn_inability_to_detect_numa()
        return rank_to_cpus

    # psutil documents that cpu_count() returns None when the value cannot
    # be determined; guard it so the integer division below cannot raise.
    cpu_count = psutil.cpu_count(logical=False)
    if not cpu_count:
        logger.warning(
            "Auto thread-binding failed due to the inability to "
            "detect the number of physical CPU cores. %s",
            self.manually_bind_threads_suggestion)
        return rank_to_cpus

    cpu_count_per_numa = cpu_count // numa_size
    num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                              cpu_count_per_numa // 2)

    # Keep only nodes that still own at least one allowed CPU. Each list is
    # sorted because set iteration order is arbitrary, and the slice below
    # must pick the same CPUs on every run (the ppc64le variant sorts too).
    allowed_cpus = set(cpus_allow_list)
    node_to_cpus = []
    for node_id in range(numa_size):
        node_cpus = sorted(
            allowed_cpus.intersection(info.node_to_cpus(node_id)))
        if node_cpus:
            node_to_cpus.append(node_cpus)

    node_to_cpus_len = len(node_to_cpus)
    if world_size > node_to_cpus_len:
        self.warn_world_size_too_large(world_size, node_to_cpus_len)
        return rank_to_cpus

    end = cpu_count_per_numa - num_of_reserved_cpu
    rank_to_cpus_list = node_to_cpus[self.rank][:end]
    rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
    logger.info("auto thread-binding list: %s", rank_to_cpus)
    return rank_to_cpus
178+
179+
def libnuma_and_psutil_found(self) -> bool:
    """Report whether both the ``numa`` and ``psutil`` packages can be found.

    Returns:
        True only when ``util.find_spec`` locates both packages.
    """
    required_packages = ("numa", "psutil")
    # A list (not a generator) so both lookups always run, as before.
    return all(
        [util.find_spec(name) is not None for name in required_packages])
184+
115185
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
    """Return the CPU id binding for this rank, derived from NUMA nodes.

    Starts from ``self.local_omp_cpuid``. When both ``numa`` and ``psutil``
    are found, ``auto_thread_binding_based_on_numa_nodes`` decides the
    binding; otherwise a warning is logged and the starting value is
    returned unchanged.
    """
    default_binding = self.local_omp_cpuid
    # Setup OpenMP thread affinity based on NUMA nodes automatically
    rank_count = self.vllm_config.parallel_config.world_size

    if not self.libnuma_and_psutil_found():
        self.warn_lack_of_numa_and_psutil()
        return default_binding

    return self.auto_thread_binding_based_on_numa_nodes(
        rank_count, default_binding)
132197

133-
# check allow node_to_cpus list
134-
node_to_cpus = []
135-
for i in range(numa_size):
136-
node_intersect = set(
137-
info.node_to_cpus(i)).intersection(cpus_allow_list)
138-
if bool(node_intersect):
139-
node_to_cpus.append(list(node_intersect))
140-
141-
if world_size > len(node_to_cpus):
142-
logger.error(
143-
"Auto thread-binding failed due to "
144-
"world size: %d is larger than "
145-
"allowed NUMA nodes number: %d."
146-
"Please try to bind threads manually.", world_size,
147-
len(node_to_cpus))
148-
else:
149-
end = cpu_count_per_numa - num_of_reserved_cpu
150-
rank_to_cpus_list = node_to_cpus[self.rank][:end]
151-
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
152-
logger.info("auto thread-binding list: %s", rank_to_cpus)
198+
def select_threads_per_power_core(self,
                                  node_cpu_ids: list[int],
                                  *,
                                  smt_width: int = 8,
                                  threads_per_core: int = 4) -> list[int]:
    """Keep only the first few logical CPUs of every core-sized id group.

    A CPU id is kept when ``cpu % smt_width < threads_per_core``. With the
    defaults this is the original rule ``cpu % 8 < 4``.
    NOTE(review): this presumably assumes logical CPU ids are numbered
    consecutively within each core (8 per core by default) -- confirm
    for the target machine.

    Args:
        node_cpu_ids: Logical CPU ids to filter; input order is preserved.
        smt_width: Size of each consecutive id group (default 8).
        threads_per_core: How many ids to keep from the start of each
            group (default 4).

    Returns:
        The filtered list of CPU ids.

    Raises:
        ValueError: If ``smt_width`` is not positive or
            ``threads_per_core`` is outside ``[0, smt_width]``.
    """
    if smt_width <= 0:
        raise ValueError(f"smt_width must be positive, got {smt_width}")
    if not 0 <= threads_per_core <= smt_width:
        raise ValueError(
            f"threads_per_core must be in [0, {smt_width}], "
            f"got {threads_per_core}")
    return [cpu for cpu in node_cpu_ids if cpu % smt_width < threads_per_core]
201+
202+
def auto_thread_binding_based_on_numa_nodes_ppc64le(
        self, world_size: int, rank_to_cpus: str) -> str:
    """Compute the CPU id list for this rank on ppc64le.

    One NUMA node (restricted to the CPUs this process may run on) is
    assigned per rank. That node's CPUs are filtered through
    ``select_threads_per_power_core``, and up to
    ``envs.VLLM_CPU_NUM_OF_RESERVED_CPU`` of the remaining CPUs (capped at
    half of them) are left out of the binding.

    Args:
        world_size: Number of ranks that each need their own NUMA node.
        rank_to_cpus: Binding to fall back to when auto-binding is not
            possible.

    Returns:
        A comma-separated string of CPU ids for ``self.rank``, or
        ``rank_to_cpus`` unchanged when no NUMA node is detected or there
        are fewer usable NUMA nodes than ranks (a warning is logged).
    """
    allowed_cpus, node_count = self.get_cpus_allow_list_and_numa_size()
    if not node_count:
        self.warn_inability_to_detect_numa()
        return rank_to_cpus

    # Sorted allowed CPUs per node, skipping nodes with no allowed CPU.
    per_node_cpus = []
    for node_id in range(node_count):
        overlap = set(info.node_to_cpus(node_id)).intersection(allowed_cpus)
        if overlap:
            per_node_cpus.append(sorted(overlap))

    usable_nodes = len(per_node_cpus)
    if world_size > usable_nodes:
        self.warn_world_size_too_large(world_size, usable_nodes)
        return rank_to_cpus

    rank_cpus = self.select_threads_per_power_core(per_node_cpus[self.rank])
    available = len(rank_cpus)
    reserved = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, available // 2)
    binding = ','.join(map(str, rank_cpus[:available - reserved]))
    logger.info("ppc64le thread-binding list: %s", binding)
    return binding
160231

161232
def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
@@ -166,48 +237,11 @@ def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
166237
performance by avoiding oversubscription of logical CPUs on Power.
167238
"""
168239

169-
def select_threads_per_power_core(node_cpu_ids):
170-
return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]
171-
172240
rank_to_cpus = self.local_omp_cpuid
173241
world_size = self.vllm_config.parallel_config.world_size
174-
libnuma_found = util.find_spec("numa") is not None
175-
psutil_found = util.find_spec("psutil") is not None
176-
if libnuma_found and psutil_found:
177-
import psutil
178-
from numa import info
179-
cpus_allow_list = psutil.Process().cpu_affinity()
180-
numa_size = info.get_num_configured_nodes()
181-
182-
node_to_cpus = []
183-
for i in range(numa_size):
184-
node_intersect = set(
185-
info.node_to_cpus(i)).intersection(cpus_allow_list)
186-
if bool(node_intersect):
187-
node_to_cpus.append(sorted(list(node_intersect)))
188-
189-
if world_size > len(node_to_cpus):
190-
logger.error(
191-
"Auto thread-binding failed due to "
192-
"world size: %d is larger than "
193-
"allowed NUMA nodes number: %d."
194-
"Please try to bind threads manually.", world_size,
195-
len(node_to_cpus))
196-
else:
197-
node_cpus_this_rank = node_to_cpus[self.rank]
198-
node_cpus_this_rank = select_threads_per_power_core(
199-
node_cpus_this_rank)
200-
cpu_count_per_numa = len(node_cpus_this_rank)
201-
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
202-
cpu_count_per_numa // 2)
203-
end = cpu_count_per_numa - num_of_reserved_cpu
204-
rank_to_cpus_list = node_cpus_this_rank[:end]
205-
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
206-
logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
242+
if self.libnuma_and_psutil_found():
243+
rank_to_cpus = self.auto_thread_binding_based_on_numa_nodes_ppc64le(
244+
world_size, rank_to_cpus)
207245
else:
208-
logger.warning(
209-
"Auto thread-binding is not supported due to "
210-
"the lack of package numa and psutil,"
211-
"fallback to no thread-binding. To get better performance,"
212-
"please try to manually bind threads.")
246+
self.warn_lack_of_numa_and_psutil()
213247
return rank_to_cpus

0 commit comments

Comments
 (0)