
Commit 6db31e7

Authored by Akash Kaothalkar (Akashcodes732) and Nikhil Gupta (nikhil-arm)
[Hardware][PPC64LE] Enable V1 for ppc64le and ARM (#20554)
Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
1 parent: 977180c

4 files changed (+77, -13 lines)


vllm/engine/arg_utils.py

Lines changed: 9 additions & 6 deletions
@@ -36,6 +36,7 @@
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -1096,7 +1097,6 @@ def create_engine_config(
         If VLLM_USE_V1 is specified by the user but the VllmConfig
         is incompatible, we raise an error.
         """
-        from vllm.platforms import current_platform
         current_platform.pre_register_and_update()

         device_config = DeviceConfig(
@@ -1123,9 +1123,16 @@ def create_engine_config(
         # Set default arguments for V0 or V1 Engine.
         if use_v1:
             self._set_default_args_v1(usage_context, model_config)
+            # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
+            if current_platform.is_cpu(
+            ) and current_platform.get_cpu_architecture() in (
+                    CpuArchEnum.POWERPC, CpuArchEnum.ARM):
+                logger.info(
+                    "Chunked prefill is not supported for ARM and POWER CPUs; "
+                    "disabling it for V1 backend.")
+                self.enable_chunked_prefill = False
         else:
             self._set_default_args_v0(model_config)
-
         assert self.enable_chunked_prefill is not None

         if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
@@ -1242,7 +1249,6 @@ def create_engine_config(
             if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                 raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                  "for pipeline-parallel-size > 1")
-            from vllm.platforms import current_platform
             if current_platform.is_cpu():
                 logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                                "currently not supported for CPUs and has been "
@@ -1391,7 +1397,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # Skip this check if we are running on a non-GPU platform,
         # or if the device capability is not available
         # (e.g. in a Ray actor without GPUs).
-        from vllm.platforms import current_platform
         if (current_platform.is_cuda()
                 and current_platform.get_device_capability()
                 and current_platform.get_device_capability().major < 8):
@@ -1652,7 +1657,6 @@ def _set_default_args_v1(self, usage_context: UsageContext,
         # as the platform that vLLM is running on (e.g. the case of scaling
         # vLLM with Ray) and has no GPUs. In this case we use the default
         # values for non-H100/H200 GPUs.
-        from vllm.platforms import current_platform
         try:
             device_memory = current_platform.get_device_total_memory()
             device_name = current_platform.get_device_name().lower()
@@ -1755,7 +1759,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        from vllm.platforms import current_platform
         current_platform.pre_register_and_update(parser)
         return parser
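
The key behavioral change above is that the V1 engine path now forces chunked prefill off whenever the current platform is a CPU whose architecture is POWERPC or ARM. A minimal standalone sketch of that control flow (the CpuArchEnum and FakeCpuPlatform below are simplified stand-ins for the real vllm.platforms objects, for illustration only):

    from enum import Enum

    # Simplified stand-ins for vllm.platforms.CpuArchEnum / current_platform,
    # used only to illustrate the control flow added in this commit.
    class CpuArchEnum(Enum):
        X86 = "x86"
        POWERPC = "powerpc"
        ARM = "arm"

    class FakeCpuPlatform:
        def __init__(self, arch: CpuArchEnum):
            self.arch = arch

        def is_cpu(self) -> bool:
            return True

        def get_cpu_architecture(self) -> CpuArchEnum:
            return self.arch

    def resolve_chunked_prefill(platform, enable_chunked_prefill: bool) -> bool:
        # Mirrors the V1 branch above: POWER/ARM CPUs force chunked prefill off.
        if platform.is_cpu() and platform.get_cpu_architecture() in (
                CpuArchEnum.POWERPC, CpuArchEnum.ARM):
            return False
        return enable_chunked_prefill

    assert resolve_chunked_prefill(FakeCpuPlatform(CpuArchEnum.ARM), True) is False
    assert resolve_chunked_prefill(FakeCpuPlatform(CpuArchEnum.X86), True) is True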

vllm/platforms/cpu.py

Lines changed: 3 additions & 2 deletions
@@ -271,5 +271,6 @@ def default_v1(cls, model_config) -> bool:
         """Returns whether the current platform can use v1 by default for the
         supplied model configuration.
         """
-        return cls.supports_v1(
-            model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86
+        arch = cls.get_cpu_architecture()
+        return (cls.supports_v1(model_config) and arch
+                in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
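
In effect, the CPU default-V1 gate is now a set-membership test instead of an equality check against x86. One quick way to see which branch a given host would take is to inspect its machine string, as in this standard-library sketch (the mapping from `uname -m` strings to vLLM's architecture enum is an assumption here, for illustration):

    import platform

    # Rough mapping from `uname -m` machine strings to the architectures that
    # now default to the V1 engine on CPU (assumed correspondence).
    V1_DEFAULT_MACHINES = {"x86_64", "ppc64le", "aarch64", "arm64"}

    def host_defaults_to_v1() -> bool:
        return platform.machine().lower() in V1_DEFAULT_MACHINES

    print(f"{platform.machine()} defaults to V1 on CPU: {host_defaults_to_v1()}")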

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 4 additions & 2 deletions
@@ -316,7 +316,6 @@ def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable) -> None:
         self.runner = runner
         self.block_table = block_table
-
         # For reorder
         self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                       dtype=np.int64)
@@ -401,11 +400,14 @@ def build(self, common_prefix_len: int,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
+            # to ensure inference when chunked_prefill is disabled
+            seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
             seq_lens_tensor=runner.
             seq_lens_cpu[num_prompt_req:num_reqs],  # decode
             max_decode_seq_len=max_decode_seq_len,  # decode
             block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
-            chunked_prefill=True,
+            chunked_prefill=self.runner.scheduler_config.
+            chunked_prefill_enabled,
             max_query_len=max_query_len,
             max_kv_len=max_prefill_seq_len,
             prefill_query_start_loc=runner.
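
Two things change in the metadata built above: full per-request sequence lengths are now passed (needed for inference when chunked prefill is disabled), and the chunked_prefill flag mirrors the scheduler config instead of being hard-coded to True. A minimal sketch of that wiring (the dataclasses are illustrative stand-ins, not the real vLLM metadata types):

    from dataclasses import dataclass

    @dataclass
    class FakeSchedulerConfig:
        chunked_prefill_enabled: bool

    @dataclass
    class FakeAttnMetadata:
        seq_lens: list[int]    # all requests, needed when chunked prefill is off
        chunked_prefill: bool  # now mirrors the scheduler config

    def build_metadata(scheduler_config: FakeSchedulerConfig,
                       seq_lens_cpu: list[int],
                       num_reqs: int) -> FakeAttnMetadata:
        return FakeAttnMetadata(
            seq_lens=seq_lens_cpu[:num_reqs],
            # Previously this was hard-coded to True.
            chunked_prefill=scheduler_config.chunked_prefill_enabled,
        )

    meta = build_metadata(FakeSchedulerConfig(chunked_prefill_enabled=False),
                          [17, 5, 9], num_reqs=2)
    assert meta.chunked_prefill is False and meta.seq_lens == [17, 5]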

vllm/v1/worker/cpu_worker.py

Lines changed: 61 additions & 3 deletions
@@ -11,7 +11,7 @@
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_random_seed
-from vllm.platforms import current_platform
+from vllm.platforms import CpuArchEnum, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
@@ -43,8 +43,12 @@ def init_device(self):
         omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
         self.local_omp_cpuid = "all"
         if omp_cpuids == "auto":
-            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
-            )
+            if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
+                self.local_omp_cpuid = (
+                    self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
+            else:
+                self.local_omp_cpuid = (
+                    self.get_cpus_id_binding_based_on_numa_nodes())
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]

@@ -153,3 +157,57 @@ def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
                 "fallback to no thread-binding. To get better performance,"
                 "please try to manually bind threads.")
         return rank_to_cpus
+
+    def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
+        """
+        Power (ppc64le) specific: Selects a subset of threads per core for
+        each NUMA node.This is robust to SMT mode (SMT-8, SMT-4, etc)
+        because the OS only exposes available threads.This maximizes
+        performance by avoiding oversubscription of logical CPUs on Power.
+        """
+
+        def select_threads_per_power_core(node_cpu_ids):
+            return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]
+
+        rank_to_cpus = self.local_omp_cpuid
+        world_size = self.vllm_config.parallel_config.world_size
+        libnuma_found = util.find_spec("numa") is not None
+        psutil_found = util.find_spec("psutil") is not None
+        if libnuma_found and psutil_found:
+            import psutil
+            from numa import info
+            cpus_allow_list = psutil.Process().cpu_affinity()
+            numa_size = info.get_num_configured_nodes()
+
+            node_to_cpus = []
+            for i in range(numa_size):
+                node_intersect = set(
+                    info.node_to_cpus(i)).intersection(cpus_allow_list)
+                if bool(node_intersect):
+                    node_to_cpus.append(sorted(list(node_intersect)))
+
+            if world_size > len(node_to_cpus):
+                logger.error(
+                    "Auto thread-binding failed due to "
+                    "world size: %d is larger than "
+                    "allowed NUMA nodes number: %d."
+                    "Please try to bind threads manually.", world_size,
+                    len(node_to_cpus))
+            else:
+                node_cpus_this_rank = node_to_cpus[self.rank]
+                node_cpus_this_rank = select_threads_per_power_core(
+                    node_cpus_this_rank)
+                cpu_count_per_numa = len(node_cpus_this_rank)
+                num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
+                                          cpu_count_per_numa // 2)
+                end = cpu_count_per_numa - num_of_reserved_cpu
+                rank_to_cpus_list = node_cpus_this_rank[:end]
+                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
+                logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
+        else:
+            logger.warning(
+                "Auto thread-binding is not supported due to "
+                "the lack of package numa and psutil,"
+                "fallback to no thread-binding. To get better performance,"
+                "please try to manually bind threads.")
+        return rank_to_cpus
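
The heart of the new ppc64le binding is the cpu % 8 < 4 filter in select_threads_per_power_core: Power numbers hardware threads eight per core, so the filter keeps the first four threads of each core, and under SMT-4 (where only threads 0-3 of each core are exposed by the OS) it drops nothing. A standalone illustration:

    def select_threads_per_power_core(node_cpu_ids):
        # Keep only the first 4 SMT threads of each Power core (threads are
        # numbered 8 per core, so IDs 0-3, 8-11, 16-19, ... survive).
        return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

    # SMT-8 node with two cores: logical CPUs 0-15.
    assert select_threads_per_power_core(range(16)) == [0, 1, 2, 3, 8, 9, 10, 11]

    # An SMT-4 node exposes only threads 0-3 of each core, so nothing is dropped.
    assert select_threads_per_power_core([0, 1, 2, 3, 8, 9, 10, 11]) == \
           [0, 1, 2, 3, 8, 9, 10, 11]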
