
Commit 9368cc9

Automatically bind CPU OMP Threads of a rank to CPU ids of a NUMA node. (vllm-project#17930)
Signed-off-by: Tsai, Louie <louie.tsai@intel.com>
Co-authored-by: Li, Jiang <bigpyj64@gmail.com>

1 parent 32b3946

File tree

6 files changed: +134 -8 lines changed

docs/getting_started/installation/cpu.md
Lines changed: 18 additions & 3 deletions

@@ -110,8 +110,9 @@ vLLM CPU backend supports the following vLLM features:

 ## Related runtime environment variables

-- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
-- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound on CPU cores 0-31, and the OpenMP threads of rank 1 are bound on CPU cores 32-63.
+- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
+- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound on CPU cores 0-31, and the OpenMP threads of rank 1 are bound on CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of one NUMA node. When set to `all`, the OpenMP threads of each rank use all CPU cores available on the system. Default value is `auto`.
+- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads of each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `0`.
 - `VLLM_CPU_MOE_PREPACK`: whether to use prepack for the MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).

 ## Performance tips
@@ -133,7 +134,15 @@ export VLLM_CPU_OMP_THREADS_BIND=0-29
 vllm serve facebook/opt-125m
 ```

-- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND`. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
+or using the default auto thread binding:
+
+```console
+export VLLM_CPU_KVCACHE_SPACE=40
+export VLLM_CPU_NUM_OF_RESERVED_CPU=2
+vllm serve facebook/opt-125m
+```
+
+- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND`, or to rely on the auto thread-binding feature enabled by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:

 ```console
 $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
@@ -178,6 +187,12 @@ $ python examples/offline_inference/basic/basic.py
 VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
 ```

+or using the default auto thread binding:
+
+```console
+VLLM_CPU_KVCACHE_SPACE=40 vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
+```
+
 - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee that the threads in each list belong to the same NUMA node.

 - Meanwhile, users should also take care of the memory capacity of each NUMA node. The memory usage of each TP rank is the sum of the `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`; if it exceeds the capacity of a single NUMA node, the TP worker will be killed due to out-of-memory.
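A quick way to check which CPU ids belong to which NUMA node before setting `VLLM_CPU_OMP_THREADS_BIND` manually is a short script against the same `py-libnuma` and `psutil` packages this commit adds as dependencies (a sketch; the printed layout is illustrative):

```python
# Inspect NUMA topology with the packages the auto-binding feature relies on.
import psutil
from numa import info  # provided by the py-libnuma package

print("physical cores:", psutil.cpu_count(logical=False))
print("allowed CPUs:  ", psutil.Process().cpu_affinity())
for node in range(info.get_num_configured_nodes()):
    # CPU ids that one rank's VLLM_CPU_OMP_THREADS_BIND entry should stay within
    print(f"NUMA node {node}:", sorted(info.node_to_cpus(node)))
```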

requirements/cpu.txt
Lines changed: 2 additions & 0 deletions

@@ -27,3 +27,5 @@ triton==3.2.0; platform_machine == "x86_64"
 # Intel Extension for PyTorch, only for x86_64 CPUs
 intel-openmp==2024.2.1; platform_machine == "x86_64"
 intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64"
+py-libnuma; platform_system != "Darwin"
+psutil; platform_system != "Darwin"
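Both packages are optional at runtime (the worker falls back to no thread-binding when they are missing), so a quick availability check can mirror the `find_spec` probing the workers perform before attempting auto-binding:

```python
# Check that the optional NUMA-binding dependencies are importable,
# the same way the CPU workers probe for them.
from importlib import util

for pkg in ("numa", "psutil"):  # the "numa" module comes from py-libnuma
    print(pkg, "found" if util.find_spec(pkg) is not None else "missing")
```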

vllm/envs.py
Lines changed: 7 additions & 1 deletion

@@ -44,6 +44,7 @@
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
+    VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
     VLLM_CPU_MOE_PREPACK: bool = True
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
@@ -422,7 +423,12 @@ def get_vllm_port() -> Optional[int]:
     # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
     # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
     "VLLM_CPU_OMP_THREADS_BIND":
-    lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
+    lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"),
+
+    # (CPU backend only) Number of CPU cores reserved per rank; these cores
+    # will not be used by the OMP threads of that rank.
+    "VLLM_CPU_NUM_OF_RESERVED_CPU":
+    lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")),

     # (CPU backend only) whether to use prepack for MoE layer. This will be
     # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
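These registry entries resolve lazily: `vllm.envs` exposes each key as a module attribute whose value is the lambda's result at access time. A minimal sketch of how the new variable is consumed (assuming the standard `vllm.envs` attribute access; the exported value here is illustrative):

```python
import os

# Override before use; the lambda reads the environment at access time.
os.environ["VLLM_CPU_NUM_OF_RESERVED_CPU"] = "2"

import vllm.envs as envs

print(envs.VLLM_CPU_OMP_THREADS_BIND)     # "auto" unless overridden
print(envs.VLLM_CPU_NUM_OF_RESERVED_CPU)  # 2, parsed to int by the lambda
```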

vllm/platforms/cpu.py
Lines changed: 3 additions & 0 deletions

@@ -208,6 +208,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

+        # Share the cpusets list among ranks by spawning processes instead of forking
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
         # Intel OpenMP setting
         ld_prealod_str = os.getenv("LD_PRELOAD", "")
         if "libiomp5.so" in ld_prealod_str:

vllm/v1/worker/cpu_worker.py
Lines changed: 53 additions & 2 deletions

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
+from importlib import util
 from typing import Optional

 import torch
@@ -38,10 +39,14 @@ def __init__(self,
     def init_device(self):
         # Setup OpenMP threads affinity.
         omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
-        if omp_cpuids == "all":
-            self.local_omp_cpuid = "all"
+        self.local_omp_cpuid = "all"
+        if omp_cpuids == "auto":
+            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
+            )
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
+
+        if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
             if ret:
                 logger.info(ret)
@@ -99,3 +104,49 @@ def execute_model(

         assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None
+
+    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
+        """Return CPU ids binding based on NUMA nodes.
+        """
+        rank_to_cpus = self.local_omp_cpuid
+        # Set up OpenMP thread affinity based on NUMA nodes automatically
+        world_size = self.vllm_config.parallel_config.world_size
+        libnuma_found = util.find_spec("numa") is not None
+        psutil_found = util.find_spec("psutil") is not None
+        if libnuma_found and psutil_found:
+            import psutil
+            from numa import info
+            cpu_count = psutil.cpu_count(logical=False)
+            cpus_allow_list = psutil.Process().cpu_affinity()
+            numa_size = info.get_num_configured_nodes()
+            cpu_count_per_numa = cpu_count // numa_size
+            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
+                                      cpu_count_per_numa // 2)
+
+            # Keep only the NUMA nodes whose CPUs intersect the allowed list
+            node_to_cpus = []
+            for i in range(numa_size):
+                node_intersect = set(
+                    info.node_to_cpus(i)).intersection(cpus_allow_list)
+                if bool(node_intersect):
+                    node_to_cpus.append(list(node_intersect))
+
+            if world_size > len(node_to_cpus):
+                logger.error(
+                    "Auto thread-binding failed: world size %d is larger "
+                    "than the number of allowed NUMA nodes (%d). "
+                    "Please try to bind threads manually.", world_size,
+                    len(node_to_cpus))
+            else:
+                end = cpu_count_per_numa - num_of_reserved_cpu
+                rank_to_cpus_list = node_to_cpus[self.rank][:end]
+                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
+                logger.info("auto thread-binding list: %s", rank_to_cpus)
+        else:
+            logger.warning(
+                "Auto thread-binding is not supported due to the lack of "
+                "the numa and psutil packages; falling back to no "
+                "thread-binding. To get better performance, please try "
+                "to bind threads manually.")
+        return rank_to_cpus
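Pulled out of the worker, the binding computation above reduces to a small standalone function. This sketch (the `rank`, `world_size`, and `reserved` arguments are illustrative stand-ins for vLLM's config and env plumbing) can be run directly on a target machine to preview the cpuset a rank would get:

```python
# Standalone preview of the auto thread-binding computation.
import psutil
from numa import info  # py-libnuma


def preview_auto_binding(rank: int, world_size: int, reserved: int = 0) -> str:
    cpu_count = psutil.cpu_count(logical=False)     # physical cores only
    allowed = set(psutil.Process().cpu_affinity())  # cores we may use
    numa_size = info.get_num_configured_nodes()
    cpus_per_numa = cpu_count // numa_size
    reserved = min(reserved, cpus_per_numa // 2)    # cap, as the worker does

    # Keep only NUMA nodes that intersect this process's allowed-CPU set.
    node_to_cpus = []
    for i in range(numa_size):
        cpus = sorted(set(info.node_to_cpus(i)) & allowed)
        if cpus:
            node_to_cpus.append(cpus)

    if world_size > len(node_to_cpus):
        raise RuntimeError("more ranks than usable NUMA nodes; bind manually")

    # Rank r gets NUMA node r's CPUs, minus the reserved tail.
    end = cpus_per_numa - reserved
    return ",".join(str(c) for c in node_to_cpus[rank][:end])


print(preview_auto_binding(rank=0, world_size=1, reserved=2))
```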

vllm/worker/cpu_worker.py
Lines changed: 51 additions & 2 deletions

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A CPU worker class."""
 import os
+from importlib import util
 from typing import Dict, List, Optional, Set, Tuple, Type

 import torch
@@ -156,8 +157,10 @@ def __init__(

         # Setup OpenMP threads affinity.
         omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
-        if omp_cpuids == "all":
-            self.local_omp_cpuid = "all"
+        self.local_omp_cpuid = "all"
+        if omp_cpuids == "auto":
+            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
+            )
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]

@@ -399,3 +402,49 @@ def get_cache_block_size_bytes(self) -> int:
         return CPUCacheEngine.get_cache_block_size(
             self.cache_config.block_size, self.cache_config.cache_dtype,
             self.model_config, self.parallel_config)
+
+    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
+        """Return CPU ids binding based on NUMA nodes.
+        """
+        rank_to_cpus = self.local_omp_cpuid
+        # Set up OpenMP thread affinity based on NUMA nodes automatically
+        world_size = self.vllm_config.parallel_config.world_size
+        libnuma_found = util.find_spec("numa") is not None
+        psutil_found = util.find_spec("psutil") is not None
+        if libnuma_found and psutil_found:
+            import psutil
+            from numa import info
+            cpu_count = psutil.cpu_count(logical=False)
+            cpus_allow_list = psutil.Process().cpu_affinity()
+            numa_size = info.get_num_configured_nodes()
+            cpu_count_per_numa = cpu_count // numa_size
+            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
+                                      cpu_count_per_numa // 2)
+
+            # Keep only the NUMA nodes whose CPUs intersect the allowed list
+            node_to_cpus = []
+            for i in range(numa_size):
+                node_intersect = set(
+                    info.node_to_cpus(i)).intersection(cpus_allow_list)
+                if bool(node_intersect):
+                    node_to_cpus.append(list(node_intersect))
+
+            if world_size > len(node_to_cpus):
+                logger.error(
+                    "Auto thread-binding failed: world size %d is larger "
+                    "than the number of allowed NUMA nodes (%d). "
+                    "Please try to bind threads manually.", world_size,
+                    len(node_to_cpus))
+            else:
+                end = cpu_count_per_numa - num_of_reserved_cpu
+                rank_to_cpus_list = node_to_cpus[self.rank][:end]
+                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
+                logger.info("auto thread-binding list: %s", rank_to_cpus)
+        else:
+            logger.warning(
+                "Auto thread-binding is not supported due to the lack of "
+                "the numa and psutil packages; falling back to no "
+                "thread-binding. To get better performance, please try "
+                "to bind threads manually.")
+        return rank_to_cpus
