Commit 5b03235

[Attention] MLA - Flashinfer Ragged Prefill (#20034)
1 parent 922f316 commit 5b03235

File tree

10 files changed: +422 additions, −215 deletions

tests/v1/kv_connector/__init__.py

Whitespace-only changes.

tests/v1/kv_connector/unit/test_multi_connector.py

Lines changed: 15 additions & 72 deletions
@@ -3,16 +3,10 @@
 import filecmp
 import shutil
 import tempfile
-from collections import defaultdict
 from pathlib import Path
 
 from vllm import LLM, SamplingParams
-from vllm.config import KVTransferConfig, VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.factory import (
-    KVConnectorFactory)
-from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
-    SharedStorageConnector)
-from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+from vllm.config import KVTransferConfig
 
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

@@ -25,65 +19,6 @@
 SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
 
 
-class TestSharedStorageConnector(SharedStorageConnector):
-
-    def __init__(self, config: VllmConfig, role):
-        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
-        self._connector = SharedStorageConnector(config, role)
-        self.call_record: dict[str, int] = defaultdict(int)
-        # Use a unique temp file per connector
-        self._event_file = tempfile.gettempdir(
-        ) + f"/connector_{self.name}-{self.role.name}_events.log"
-        # Start with an empty file
-        with open(self._event_file, "w") as _:
-            pass
-
-    def __getattribute__(self, name):
-        if name in ("_connector", "call_record", "name", "_event_file",
-                    "__class__", "__dict__", "__getattribute__",
-                    "__init__"):  # avoid recursion
-            return object.__getattribute__(self, name)
-        if not hasattr(self._connector, name):
-            return object.__getattribute__(self, name)
-        attr = getattr(self._connector, name)
-
-        # Intercept calls to the connector interface and write an event
-        # for each one to a file, which can be read back in the main test proc.
-        if callable(attr):
-
-            def wrapper(*args, **kwargs):
-                self.call_record[name] += 1
-
-                # Include args that we're interested in
-                to_log = [name]
-                for arg in args:
-                    if isinstance(arg, int):
-                        to_log.append(str(arg))
-                    elif isinstance(arg, KVCacheBlocks):
-                        to_log.append(
-                            f"num_blocks={[len(b) for b in arg.blocks]}")
-
-                # Log the event as a line to the file
-                try:
-                    with open(self._event_file, "a") as f:
-                        f.write(' '.join(to_log) + "\n")
-                except Exception as e:
-                    print(f"[ERROR] Could not log event {name} "
-                          f"for {self.name}: {e}")
-                return attr(*args, **kwargs)
-
-            return wrapper
-        return attr
-
-
-# This relies on "fork" multiprocessing method being used.
-# It's the default but vLLM may fall back to spawn if for example CUDA
-# is already initialized.
-KVConnectorFactory.register_connector("TestSharedStorageConnector",
-                                      TestSharedStorageConnector.__module__,
-                                      TestSharedStorageConnector.__name__)
-
-
 # Helper function to compare directories recursively
 def _compare_directories(dir1: Path, dir2: Path) -> bool:
     """Compares two directories recursively for identical content."""

@@ -118,19 +53,27 @@ def test_multi_shared_storage_connector_consistency():
         kv_role="kv_both",
         kv_connector_extra_config={
             "connectors": [{
-                "kv_connector": "TestSharedStorageConnector",
-                "kv_role": "kv_both",
+                "kv_connector":
+                "TestSharedStorageConnector",
+                "kv_role":
+                "kv_both",
                 "kv_connector_extra_config": {
                     "shared_storage_path": str(storage_1_path),
                     "name": "storage1",
-                }
+                },
+                "kv_connector_module_path":
+                "tests.v1.kv_connector.unit.utils",
             }, {
-                "kv_connector": "TestSharedStorageConnector",
-                "kv_role": "kv_both",
+                "kv_connector":
+                "TestSharedStorageConnector",
+                "kv_role":
+                "kv_both",
                 "kv_connector_extra_config": {
                     "shared_storage_path": str(storage_2_path),
                     "name": "storage2",
-                }
+                },
+                "kv_connector_module_path":
+                "tests.v1.kv_connector.unit.utils",
             }]
         },
     )
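
The substance of this test change: TestSharedStorageConnector moves out of the test module (into tests/v1/kv_connector/unit/utils.py, below), and each connector entry now carries kv_connector_module_path so the engine process can import and register the connector itself rather than relying on the registration being inherited through fork. A minimal single-connector sketch of the same mechanism — the field names follow this diff, while the storage path is illustrative:

    from vllm import LLM
    from vllm.config import KVTransferConfig

    # Sketch, not part of this commit: the engine imports the named module,
    # which registers TestSharedStorageConnector with KVConnectorFactory at
    # import time, so this works under "spawn" as well as "fork".
    kv_transfer_config = KVTransferConfig(
        kv_connector="TestSharedStorageConnector",
        kv_connector_module_path="tests.v1.kv_connector.unit.utils",
        kv_role="kv_both",
        kv_connector_extra_config={
            "shared_storage_path": "/tmp/kv_storage",
            "name": "storage1",
        },
    )

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              kv_transfer_config=kv_transfer_config)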

tests/v1/kv_connector/unit/utils.py

Lines changed: 62 additions & 0 deletions
@@ -1,12 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import tempfile
+from collections import defaultdict
 from typing import Any, Optional
 
 import torch
 
 from vllm import SamplingParams
 from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                          ModelConfig, SchedulerConfig, VllmConfig)
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
+    SharedStorageConnector)
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)

@@ -187,3 +194,58 @@ def create_model_runner_output(
         finished_sending=finished_sending,
         finished_recving=finished_recving,
     )
+
+
+class TestSharedStorageConnector(SharedStorageConnector):
+
+    def __init__(self, config: VllmConfig, role):
+        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
+        self._connector = SharedStorageConnector(config, role)
+        self.call_record: dict[str, int] = defaultdict(int)
+        # Use a unique temp file per connector
+        self._event_file = tempfile.gettempdir(
+        ) + f"/connector_{self.name}-{self.role.name}_events.log"
+        # Start with an empty file
+        with open(self._event_file, "w") as _:
+            pass
+
+    def __getattribute__(self, name):
+        if name in ("_connector", "call_record", "name", "_event_file",
+                    "__class__", "__dict__", "__getattribute__",
+                    "__init__"):  # avoid recursion
+            return object.__getattribute__(self, name)
+        if not hasattr(self._connector, name):
+            return object.__getattribute__(self, name)
+        attr = getattr(self._connector, name)
+
+        # Intercept calls to the connector interface and write an event
+        # for each one to a file, which can be read back in the main test proc.
+        if callable(attr):
+
+            def wrapper(*args, **kwargs):
+                self.call_record[name] += 1
+
+                # Include args that we're interested in
+                to_log = [name]
+                for arg in args:
+                    if isinstance(arg, int):
+                        to_log.append(str(arg))
+                    elif isinstance(arg, KVCacheBlocks):
+                        to_log.append(
+                            f"num_blocks={[len(b) for b in arg.blocks]}")
+
+                # Log the event as a line to the file
+                try:
+                    with open(self._event_file, "a") as f:
+                        f.write(' '.join(to_log) + "\n")
+                except Exception as e:
+                    print(f"[ERROR] Could not log event {name} "
+                          f"for {self.name}: {e}")
+                return attr(*args, **kwargs)
+
+            return wrapper
+        return attr
+
+
+KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
+                                      TestSharedStorageConnector.__name__)
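
Since the wrapper appends one space-separated line per intercepted call, the main test process can recover each connector's activity after the engine workers exit. A short sketch of reading the log back; the role string ("SCHEDULER") and the assertion are illustrative assumptions, not part of this commit:

    import tempfile

    def read_connector_events(name: str, role: str) -> list[list[str]]:
        # Mirrors the _event_file naming scheme used by
        # TestSharedStorageConnector above:
        # <tmpdir>/connector_<name>-<role>_events.log
        path = f"{tempfile.gettempdir()}/connector_{name}-{role}_events.log"
        with open(path) as f:
            return [line.split() for line in f if line.strip()]

    # Hypothetical check: the scheduler-side "storage1" connector was used.
    events = read_connector_events("storage1", "SCHEDULER")
    assert events, "expected at least one intercepted connector call"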

vllm/attention/layer.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,7 @@
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
+from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,

@@ -21,7 +22,6 @@
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
-from vllm.v1.attention.backends.utils import validate_kv_sharing_target
 
 
 class Attention(nn.Module):
vllm/attention/utils/kv_sharing_utils.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+def validate_kv_sharing_target(current_layer_name, target_layer_name,
+                               static_forward_context):
+    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
+                 f"is not valid: target layer {target_layer_name} ")
+
+    if current_layer_name == target_layer_name:
+        raise ValueError(error_msg +
+                         "cannot be the same as the current layer.")
+
+    if target_layer_name not in static_forward_context:
+        from vllm.model_executor.models.utils import extract_layer_index
+
+        # If target layer name is not in the static fwd context, it means either
+        # a) the target layer does not come BEFORE the current layer, or
+        # b) the target layer is not an Attention layer that exists in the model
+        current_layer_idx = extract_layer_index(current_layer_name)
+        target_layer_idx = extract_layer_index(target_layer_name)
+        if current_layer_idx <= target_layer_idx:
+            raise ValueError(error_msg + "must come before the current layer.")
+        else:
+            raise ValueError(error_msg +
+                             "is not a valid Attention layer in the model.")
+
+    # Currently KV sharing is only supported between layers of the same type
+    target_layer_attn_type = static_forward_context[
+        target_layer_name].attn_type
+    expected = static_forward_context[current_layer_name].attn_type
+    if target_layer_attn_type != expected:
+        raise ValueError(
+            error_msg +
+            f"must be the same type as the current layer ({expected}).")

vllm/logger.py

Lines changed: 14 additions & 0 deletions
@@ -53,6 +53,12 @@
 }
 
 
+@lru_cache
+def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.debug(msg, *args, stacklevel=2)
+
+
 @lru_cache
 def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
     # Set the stacklevel to 2 to print the original caller's line info

@@ -74,6 +80,13 @@ class _VllmLogger(Logger):
     `intel_extension_for_pytorch.utils._logger`.
     """
 
+    def debug_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`debug`][logging.Logger.debug], but subsequent calls with
+        the same message are silently dropped.
+        """
+        _print_debug_once(self, msg, *args)
+
     def info_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`info`][logging.Logger.info], but subsequent calls with

@@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger:
     logger = logging.getLogger(name)
 
     methods_to_patch = {
+        "debug_once": _print_debug_once,
         "info_once": _print_info_once,
         "warning_once": _print_warning_once,
     }
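
The new debug_once follows the existing info_once/warning_once pattern: lru_cache memoizes _print_debug_once on the (logger, msg, *args) tuple, so only the first call with a given message actually logs — which is why the args are typed Hashable. A small usage sketch with an illustrative message:

    from vllm.logger import init_logger

    logger = init_logger(__name__)

    for _ in range(32):
        # Logged exactly once; the 31 repeated calls hit the lru_cache on
        # _print_debug_once and are silently dropped.
        logger.debug_once("Using FlashInfer ragged prefill for MLA")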

vllm/v1/attention/backends/flashinfer.py

Lines changed: 5 additions & 68 deletions
@@ -14,13 +14,14 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
-from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
-                                              get_kv_cache_layout)
+                                              PerLayerParameters,
+                                              get_kv_cache_layout,
+                                              get_per_layer_parameters,
+                                              infer_global_hyperparameters)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -93,70 +94,6 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
     return stride_order
 
 
-@dataclass
-class PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
-    """
-
-    window_left: int
-    logits_soft_cap: Optional[float]
-    sm_scale: float
-
-
-def get_per_layer_parameters(
-        vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
-    """
-    Scan all attention layers and determine some hyperparameters
-    to use during `plan`.
-    """
-
-    layers = get_layers_from_vllm_config(vllm_config, Attention)
-    per_layer_params: dict[str, PerLayerParameters] = {}
-
-    for key, layer in layers.items():
-        impl = layer.impl
-        assert isinstance(impl, FlashInferImpl)
-
-        # Infer hyperparameters from the attention layer
-        window_size = impl.sliding_window
-        window_left = window_size[0] if window_size is not None else -1
-        logits_soft_cap = impl.logits_soft_cap
-        sm_scale = impl.scale
-
-        per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
-
-    return per_layer_params
-
-
-def infer_global_hyperparameters(
-        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters:
-    - `window_left`
-    - `logits_soft_cap`
-    - `sm_scale`
-
-    So this function asserts that all layers share the same values for these
-    hyperparameters and returns the global values.
-    """
-
-    assert len(per_layer_params) > 0, "No attention layers found in the model."
-
-    param_sets = list(per_layer_params.values())
-    global_params = param_sets[0]
-    for params in param_sets:
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
-
-    return global_params
-
-
 @dataclass
 class FlashInferMetadata:

@@ -336,7 +273,7 @@ def _get_cascade_wrapper(self):
     def _plan(self, attn_metadata: FlashInferMetadata):
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config))
+                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
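
Net effect of this file's change: PerLayerParameters, get_per_layer_parameters, and infer_global_hyperparameters move from the FlashInfer backend into the shared vllm.v1.attention.backends.utils module, and the layer scan now takes the backend's impl class instead of hard-coding FlashInferImpl, so the MLA ragged-prefill path can reuse the same uniformity check. A sketch of the new call shape, with the generalized signature inferred from this call site rather than quoted from utils.py:

    from vllm.v1.attention.backends.utils import (get_per_layer_parameters,
                                                  infer_global_hyperparameters)

    # Inside the backend's metadata builder (self.vllm_config is the engine
    # config): collect window_left/logits_soft_cap/sm_scale for each layer
    # whose impl is FlashInferImpl, then assert they agree across the model.
    global_hyperparameters = infer_global_hyperparameters(
        get_per_layer_parameters(self.vllm_config, FlashInferImpl))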
