
Commit 3cd2474

Address comments
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent 55ddaa0 commit 3cd2474

File tree: 5 files changed (+24, -32 lines)

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 13 additions & 15 deletions
@@ -13,7 +13,7 @@

 from vllm import LLM, SamplingParams
 from vllm.compilation.backends import set_model_tag
-from vllm.compilation.decorators import (skip_torch_compile,
+from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          VllmConfig)

@@ -161,7 +161,7 @@ def forward(
         return hidden_states, residual


-@skip_torch_compile
+@ignore_torch_compile
 class Qwen2ModelWithKVSharing(Qwen2Model):

     def __init__(self,

@@ -193,18 +193,17 @@ def __init__(self,
         )

         # Pre-allocate static buffers for CUDA graph
-        self.max_num_tokens = \
-            vllm_config.scheduler_config.max_num_batched_tokens
-        self.dtype = vllm_config.model_config.dtype
-        self.device = next(self.parameters()).device
-        self.hidden_size = vllm_config.model_config.get_hidden_size()
-        self.residual = torch.zeros((self.max_num_tokens, self.hidden_size),
-                                    dtype=self.dtype,
-                                    device=self.device)
+        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        dtype = vllm_config.model_config.dtype
+        device = next(self.parameters()).device
+        hidden_size = vllm_config.model_config.get_hidden_size()
+        self.residual = torch.zeros((max_num_tokens, hidden_size),
+                                    dtype=dtype,
+                                    device=device)
         self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+            (max_num_tokens, hidden_size),
+            dtype=dtype,
+            device=device)

     def forward(
         self,

@@ -355,8 +354,7 @@ def test_kv_sharing_skip_prefill(
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
-        if not enforce_eager else CompilationLevel.NO_COMPILATION,
-        cudagraph_share_memory_pool=False)
+        if not enforce_eager else CompilationLevel.NO_COMPILATION)

     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

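Note on the @@ -193,18 +193,17 @@ hunk: only self.residual and self.hidden_states stay as attributes, because they back CUDA graph capture and must keep stable addresses; the sizing values become plain locals. Below is a minimal, self-contained sketch of that pre-allocate-and-reuse buffer pattern. StaticBufferModule and its sizes are illustrative stand-ins, not the actual test model:

import torch
import torch.nn as nn


class StaticBufferModule(nn.Module):
    """Illustrative only: allocate output buffers once and write into slices
    of them, mirroring the buffer handling in Qwen2ModelWithKVSharing above."""

    def __init__(self, max_num_tokens: int = 256, hidden_size: int = 64):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Allocated once; forward() only fills slices, so a captured CUDA
        # graph can replay against stable memory addresses.
        self.hidden_states = torch.zeros(max_num_tokens, hidden_size)
        self.residual = torch.zeros(max_num_tokens, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        self.residual[:num_tokens].copy_(x)
        self.hidden_states[:num_tokens].copy_(self.linear(x))
        return self.hidden_states[:num_tokens] + self.residual[:num_tokens]


# Two calls with different token counts reuse the same buffers.
module = StaticBufferModule()
with torch.no_grad():
    print(module(torch.randn(8, 64)).shape)   # torch.Size([8, 64])
    print(module(torch.randn(16, 64)).shape)  # torch.Size([16, 64])
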
vllm/compilation/backends.py

Lines changed: 2 additions & 5 deletions
@@ -412,11 +412,8 @@ def __init__(
         # them, e.g. backbone (default), eagle_head, etc.
         self.prefix = prefix or model_tag

-        if vllm_config.compilation_config.cudagraph_share_memory_pool:
-            global global_graph_pool
-            if global_graph_pool is None:
-                global_graph_pool = current_platform.graph_pool_handle()
-        else:
+        global global_graph_pool
+        if global_graph_pool is None:
             global_graph_pool = current_platform.graph_pool_handle()

         # TODO: in the future, if we want to use multiple

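With cudagraph_share_memory_pool removed, the backend now unconditionally lazy-initializes a single module-level graph pool. A self-contained sketch of that lazy-global pattern follows; _global_pool, make_pool, and get_global_pool are illustrative stand-ins, not vLLM APIs:

from typing import Optional

_global_pool: Optional[object] = None


def make_pool() -> object:
    # Stand-in for current_platform.graph_pool_handle(); a sentinel is enough here.
    return object()


def get_global_pool() -> object:
    """Create the shared pool on first use, then hand every caller the same handle."""
    global _global_pool
    if _global_pool is None:
        _global_pool = make_pool()
    return _global_pool


assert get_global_pool() is get_global_pool()  # all graph captures share one pool
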
vllm/compilation/decorators.py

Lines changed: 3 additions & 3 deletions
@@ -23,8 +23,8 @@
 _T = TypeVar("_T", bound=type[nn.Module])


-def skip_torch_compile(cls: _T) -> _T:
-    cls._skip_compile_vllm = True
+def ignore_torch_compile(cls: _T) -> _T:
+    cls._ignore_compile_vllm = True
     return cls


@@ -161,7 +161,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         self.do_not_compile = \
             vllm_config.compilation_config.level in [
                 CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-            ] or not supports_dynamo() or getattr(self, "_skip_compile_vllm", False)
+            ] or not supports_dynamo() or getattr(self, "_ignore_compile_vllm", False)
         if self.do_not_compile:
             return
         compilation_counter.num_models_seen += 1

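After the rename, a model class opts out of vLLM's torch.compile integration via ignore_torch_compile, and the decorator does nothing more than set a class attribute that the support_torch_compile path reads back with getattr. A sketch of that mechanism in isolation; ToyModel is illustrative, only the decorator and attribute names come from this diff:

from typing import TypeVar

import torch.nn as nn

_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
    # Same shape as the decorator in vllm/compilation/decorators.py: mark the
    # class so the compile machinery skips it.
    cls._ignore_compile_vllm = True
    return cls


@ignore_torch_compile
class ToyModel(nn.Module):
    def forward(self, x):
        return x


# The flag is read back with getattr, defaulting to False for undecorated classes.
print(getattr(ToyModel, "_ignore_compile_vllm", False))   # True
print(getattr(nn.Linear, "_ignore_compile_vllm", False))  # False
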
vllm/config.py

Lines changed: 0 additions & 5 deletions
@@ -4118,11 +4118,6 @@ class CompilationConfig:
     """Sizes to capture cudagraph.
     - None (default): capture sizes are inferred from vllm config.
     - list[int]: capture sizes are specified as given."""
-    cudagraph_share_memory_pool: bool = True
-    """Whether to share a single global memory pool for each graph capture
-    When CUDA graphs are not replayed in the same order they are captured,
-    e.g. when compiling multiple modules in a model and modules take different
-    input shapes, it is unsafe to share memory across graph captures."""
     cudagraph_copy_inputs: bool = False
     """Whether to copy input tensors for
     cudagraph. If the caller can guarantee that the same input buffers

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 4 deletions
@@ -317,9 +317,11 @@ def __init__(
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}

-        self.decode_indices = torch.zeros(self.max_num_tokens,
-                                          dtype=torch.int32,
-                                          device=self.device)
+        self.decode_indices = None
+        if self.cache_config.kv_sharing_skip_prefill:
+            self.decode_indices = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=self.device)

     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """

@@ -583,7 +585,7 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
         """
         Pads logits_indices to align with CUDA graph capture sizes
         """
-        if not self.cache_config.kv_sharing_skip_prefill:
+        if self.decode_indices is None:
            return None

        num_decode_reqs = 0

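The runner now allocates decode_indices only when kv_sharing_skip_prefill is enabled, and _calc_decode_indices checks the buffer itself (is None) instead of re-reading the config flag. A stripped-down sketch of that optional-buffer guard; ToyRunner and its fields are illustrative, not the real GPUModelRunner:

from typing import Optional

import torch


class ToyRunner:
    """Illustrative only: allocate an optional buffer up front and let a None
    check, rather than a config flag, gate the code that uses it."""

    def __init__(self, skip_prefill: bool, max_num_tokens: int = 128):
        self.decode_indices: Optional[torch.Tensor] = None
        if skip_prefill:
            self.decode_indices = torch.zeros(max_num_tokens, dtype=torch.int32)

    def calc_decode_indices(self, logits_indices: torch.Tensor):
        if self.decode_indices is None:  # feature disabled, nothing to compute
            return None
        num = logits_indices.numel()
        self.decode_indices[:num].copy_(logits_indices.to(torch.int32))
        return self.decode_indices[:num]


print(ToyRunner(skip_prefill=False).calc_decode_indices(torch.arange(4)))  # None
print(ToyRunner(skip_prefill=True).calc_decode_indices(torch.arange(4)))   # tensor([0, 1, 2, 3], dtype=torch.int32)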