
[V1][Neuron] Neuron chunked prefill V1 impl #21490

Open · wants to merge 6 commits into main
6 changes: 1 addition & 5 deletions .buildkite/scripts/hardware_ci/run-neuron-test.sh
@@ -55,10 +55,6 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
${image_name} \
/bin/bash -c "
set -e; # Exit on first error
python3 /workspace/vllm/examples/offline_inference/neuron.py;
python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
for f in /workspace/vllm/tests/neuron/2_core/*.py; do
echo \"Running test file: \$f\";
python3 -m pytest \$f -v --capture=tee-sys;
done
python3 -m pytest /workspace/vllm/tests/neuron/2_core/test_chunked_prefill.py -v --capture=tee-sys;
"
84 changes: 84 additions & 0 deletions (new offline inference example for chunked prefill)
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example is used to illustrate the usage when chunked prefill is enabled.
To run it, you need to set DISABLE_NEURON_CUSTOM_SCHEDULER=1 if the Neuron
plugin is installed.
"""

from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig

from vllm import LLM, SamplingParams

model_path = "meta-llama/Llama-3.1-8B-Instruct"

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
(
"It is not the critic who counts; not the man who points out how the "
"strong man stumbles, or where the doer of deeds could have done them "
"better. The credit belongs to the man who is actually in the arena, "
"whose face is marred by dust and sweat and blood; who strives "
"valiantly; who errs, who comes short again and again, because there "
"is no effort without error and shortcoming; but who does actually "
"strive to do the deeds; who knows great enthusiasms, the great "
"devotions; who spends himself in a worthy cause; who at the best "
"knows"
),
(
"Do not go gentle into that good night, Old age should burn and rave "
"at close of day; Rage, rage against the dying of the light. Though "
"wise men at their end know dark is right, Because their words had "
"forked no lightning they Do not go gentle into that good night. Good "
"men, the last wave by, crying how bright Their frail deeds might have "
"danced in a green bay, Rage, rage against the dying of the light. "
"Wild men who caught and sang the sun in flight, And learn, too late, "
"they grieved it on its way, Do not go gentle into that good night. "
"Grave men, near death, who see with blinding sight Blind eyes could "
"blaze like meteors and be gay, Rage, rage against the dying of the "
"light. And you, my father, there on the sad height, Curse, bless, me "
"now with your fierce tears, I pray. Do not go gentle into that good "
"night. Rage, rage against the dying of the light."
),
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=30, top_k=1)

# Create an LLM.
llm = LLM(
model=model_path,
max_num_seqs=8,
max_model_len=1024,
max_num_batched_tokens=256, # chunk size
block_size=32,
tensor_parallel_size=32,
enable_prefix_caching=False,
enable_chunked_prefill=True,
override_neuron_config={
"is_block_kv_layout": True,
"sequence_parallel_enabled": True,
"on_device_sampling_config": OnDeviceSamplingConfig(),
"chunked_prefill_config": {
"max_num_seqs": 8,
"kernel_q_tile_size": 128,
"kernel_kv_tile_size": 4096,
},
"skip_warmup": True,

Review comment: Could we test if this can be removed?

Contributor Author: Yes, inference still works correctly with warmup enabled, but it produced very verbose error logs and warmup took about 7 minutes, so I'm keeping this line for now.

# chunked prefill currently only supports LNC=1.
"logical_nc_config": 1,
},
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
79 changes: 79 additions & 0 deletions tests/neuron/2_core/test_chunked_prefill.py
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams


def test_v1_chunked_prefill():
model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

llm = LLM(
model=model_path,
max_num_seqs=8,
max_model_len=512,
max_num_batched_tokens=128, # chunk size
block_size=32,
tensor_parallel_size=2,
enable_prefix_caching=False,
enable_chunked_prefill=True,
override_neuron_config={
"is_block_kv_layout": True,
"sequence_parallel_enabled": True,
"chunked_prefill_config": {
"max_num_seqs": 8,
"kernel_q_tile_size": 128,
"kernel_kv_tile_size": 4096,
},
"skip_warmup": True,
"save_sharded_checkpoint": True,
"logical_nc_config": 1,
},
)

prompts = [
"The president of the United States is",
"The capital of France is",
("It is not the critic who counts; not the man who points out how the "
"strong man stumbles, or where the doer of deeds could have done them "
"better. The credit belongs to the man who is actually in the arena, "
"whose face is marred by dust and sweat and blood; who strives "
"valiantly; who errs, who comes short again and again, because there "
"is no effort without error and shortcoming; but who does actually "
"strive to do the deeds; who knows great enthusiasms, the great "
"devotions; who spends himself in a worthy cause; who at the best "
"knows"),
("Do not go gentle into that good night, Old age should burn and rave "
"at close of day; Rage, rage against the dying of the light. Though "
"wise men at their end know dark is right, Because their words had "
"forked no lightning they Do not go gentle into that good night. Good "
"men, the last wave by, crying how bright Their frail deeds might have"
" danced in a green bay, Rage, rage against the dying of the light. "
"Wild men who caught and sang the sun in flight, And learn, too late, "
"they grieved it on its way, Do not go gentle into that good night. "
"Grave men, near death, who see with blinding sight Blind eyes could "
"blaze like meteors and be gay, Rage, rage against the dying of the "
"light. And you, my father, there on the sad height, Curse, bless, me "
"now with your fierce tears, I pray. Do not go gentle into that good "
"night. Rage, rage against the dying of the light."),
]
sampling_params = SamplingParams(max_tokens=30, top_k=1)

outputs = llm.generate(prompts, sampling_params)

expected_outputs = [

Review comment: Are the outputs deterministic? Don't we need to set the random seed?

Contributor Author: Yes, the outputs should be deterministic since we're using greedy sampling.

" a man named Donald Trump.\n\n2. B. The president of the United States"
" is a man named Donald Trump.\n\n3. C",
" Paris.\n\n2. B. The capital of France is Paris.\n\n3. C. The capital"
" of France is Paris.\n\n",
"ends the triumph of high achievement, and at worst, if he fails, at "
"least he fails while daring greatly, so that his place shall",
" Rage, rage against the dying of the light. Rage, rage against the "
"dying of the light. Rage, rage against"
]

for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
assert (expected_output == generated_text)

print("Neuron V1 chunked prefill test passed.")
26 changes: 18 additions & 8 deletions vllm/platforms/neuron.py
@@ -12,7 +12,7 @@
from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
VllmConfig = None

@@ -45,17 +45,23 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.neuron_worker.NeuronWorker"
# TODO: validate the config. for example, some configs
# must (e.g., block_size, enable_chunked_prefill,
# etc.) be set or provided with default values.
else:
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
if vllm_config.cache_config and vllm_config.model_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore
Comment on lines +54 to +60

Review comment: This branch will need to be removed (in this PR or a follow-up), given that the v0 code paths will be gone.

Contributor Author: Prefer to leave the v0 deletion to a separate PR for simplicity.
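For context, a hypothetical sketch of what the worker selection might collapse to once the v0 path is deleted; this is an assumption about a possible follow-up, not code in this PR.

```python
# Hypothetical follow-up sketch, not part of this PR: with the v0 branch gone,
# the V1 Neuron worker would be selected unconditionally when worker_cls is "auto".
if parallel_config.worker_cls == "auto":
    parallel_config.worker_cls = "vllm.v1.worker.neuron_worker.NeuronWorker"
```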


if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"

if vllm_config.cache_config and vllm_config.model_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore

if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
@@ -91,6 +97,10 @@ def is_neuronx_distributed_inference(cls) -> bool:
neuronx_distributed_inference = None
return neuronx_distributed_inference is not None

@classmethod
def supports_v1(cls, model_config: "ModelConfig") -> bool:
return True

@classmethod
@lru_cache
def is_transformers_neuronx(cls) -> bool:
117 changes: 117 additions & 0 deletions vllm/v1/worker/neuron_worker.py
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
from typing import Optional

import torch.nn as nn

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)


class NeuronWorker(WorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""

def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False) -> None:
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.device = self.device_config.device
self.model_runner = self.get_neuronx_distributed_model_runner(
vllm_config, self.device)

def init_device(self) -> None:
self.init_distributed_environment()

# Set random seed.
set_random_seed(self.model_config.seed)

def determine_available_memory(self):
# Note: This is not needed for Neuron, thus setting to 1GB as a
# placeholder. This will be implemented in the native integration
# phase.
return 1024 * 1024 * 1024 # 1GB

def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> Optional[ModelRunnerOutput]:
assert self.model_runner is not None
output = self.model_runner.execute_model(scheduler_output)
return output if self.is_driver_worker else None

def profile(self, is_start: bool = True):
raise NotImplementedError

def get_neuronx_distributed_model_runner(self, vllm_config, device):
from vllm.v1.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
return NeuronxDistributedModelRunner(vllm_config=vllm_config,
device=device)

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def load_model(self):
assert self.model_runner is not None
self.model_runner.load_model()

def compile_or_warm_up_model(self) -> None:
# Skip for NeuronX Distributed Inference
return None

def get_model(self) -> nn.Module:
raise NotImplementedError

def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
assert self.model_runner is not None
return self.model_runner.get_kv_cache_spec()

def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
assert self.model_runner is not None
self.model_runner.initialize_kv_cache(kv_cache_config)

def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return

def init_distributed_environment(self):
"""
vLLM still needs the environment initialized when TP/PP > 1
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)

ensure_model_parallel_initialized(
1,
1,
)
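A note on the 1 GiB placeholder returned by `determine_available_memory()`: the reported figure is what the engine uses to size the KV cache, so a fixed placeholder yields a fixed block budget until native memory accounting lands. The arithmetic below is purely illustrative; the per-token KV footprint numbers are assumptions, not values read from any model, and this is not vLLM's actual sizing code.

```python
# Illustrative arithmetic only: how a 1 GiB budget maps to KV-cache blocks.
available_bytes = 1024 * 1024 * 1024      # matches determine_available_memory()
block_size_tokens = 32                    # block_size used by the example/test configs

# Assumed per-token KV footprint: 2 (K and V) * layers * kv_heads * head_dim * 2 bytes (bf16).
bytes_per_token = 2 * 32 * 8 * 128 * 2
bytes_per_block = block_size_tokens * bytes_per_token

print(available_bytes // bytes_per_block, "KV-cache blocks from the placeholder budget")
```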