@@ -3,15 +3,19 @@

import gc
from collections.abc import Iterable
-from typing import Optional, Union
+from typing import List, Optional, Union

import pytest
import torch
from torch import nn
from transformers import Qwen2Config

from vllm import LLM, SamplingParams
-from vllm.config import CacheConfig, VllmConfig
+from vllm.compilation.backends import set_model_tag
+from vllm.compilation.decorators import (skip_torch_compile,
+                                         support_torch_compile)
+from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
+                         VllmConfig)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -52,6 +56,7 @@ def __init__(
        target_layer_idx = layer_idx % 5
        kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace(
            str(layer_idx), str(target_layer_idx))
+
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -99,8 +104,72 @@ def forward(
        return hidden_states, residual


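+# Wraps a contiguous slice of decoder layers so the slice can be compiled as
+# its own torch.compile unit.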
+@support_torch_compile
+class DecoderLayerGroup(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layers: List[nn.Module],
+    ):
+        super().__init__()
+        self.layers = layers
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        return hidden_states, residual
+
+
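+# Keep the outer model out of torch.compile; the layer groups above are
+# compiled separately.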
+@skip_torch_compile
class Qwen2ModelWithKVSharing(Qwen2Model):

+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 decoder_layer_type: type[
+                     nn.Module] = Qwen2DecoderLayerWithKVSharing):
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+            decoder_layer_type=decoder_layer_type,
+        )
+
+        with set_model_tag("first_layer_group"):
+            self.first_layer_group = DecoderLayerGroup(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.first_layer_group",
+                layers=self.layers[self.start_layer:START_KV_SHARING_LAYER],
+            )
+
+        with set_model_tag("second_layer_group"):
+            self.second_layer_group = DecoderLayerGroup(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.second_layer_group",
+                layers=self.layers[START_KV_SHARING_LAYER:self.end_layer],
+            )
+
+        # Pre-allocate static buffers for CUDA graph
+        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.dtype = vllm_config.model_config.dtype
+        self.device = next(self.parameters()).device
+        self.hidden_size = vllm_config.model_config.get_hidden_size()
+        self.residual = torch.zeros((self.max_num_tokens, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -112,46 +181,40 @@ def forward(
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
+
        residual = None
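+        # First group: run all tokens through the layers before KV sharing
+        # starts.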
+        first_hidden_states, first_residual = self.first_layer_group(
+            positions,
+            hidden_states,
+            residual,  # no residual, assume no pipeline parallel
+        )

        decode_indices = get_forward_context().decode_indices
        if decode_indices is None:
            decode_indices = torch.arange(positions.size(0),
                                          device=positions.device)
-
-        # Forward with full inputs up to the first layer that shares KV cache
-        for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
-
-        if decode_indices is not None:
-            decode_hidden_states = hidden_states[decode_indices]
-            decode_positions = positions[decode_indices]
-            decode_residual = (residual[decode_indices]
-                               if residual is not None else None)
-        else:
-            decode_hidden_states = hidden_states
-            decode_positions = positions
-            decode_residual = residual
-
-        # Optimization: forward with partial inputs only for last N layers
-        for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]:
-            decode_hidden_states, decode_residual = layer(
-                decode_positions,
-                decode_hidden_states,
-                decode_residual,
-            )
+        num_decodes = decode_indices.shape[0]
+        assert num_decodes >= 1
+        assert first_residual is not None
+
+        # CUDA graph expects static tensor addresses
+        # Copy output of first layer group to second layer group
+        self.residual[:num_decodes].copy_(first_residual[decode_indices])
+        hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices])
+        positions[:num_decodes].copy_(positions[decode_indices])
+
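+        # Second group: only the decode tokens run through the KV-sharing
+        # layers, so prefill work is skipped for them.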
+        second_hidden_states, second_residual = self.second_layer_group(
+            positions[:num_decodes],
+            hidden_states[:num_decodes],
+            self.residual[:num_decodes],
+        )

        # Merge results back
-        if decode_hidden_states is not None:
-            hidden_states[decode_indices] = decode_hidden_states
-            if residual is not None:
-                residual[decode_indices] = decode_residual
+        first_hidden_states[decode_indices] = second_hidden_states
+        if first_residual is not None:
+            first_residual[decode_indices] = second_residual

-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states, _ = self.norm(first_hidden_states, first_residual)
        return hidden_states

@@ -205,20 +268,24 @@ def load_weights(self, weights: Iterable[tuple[str,
        return loader.load_weights(weights)


-# TODO: make it work with torch.compile
@fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("enforce_eager", [False, True])
def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
    prompt = "What is the capital of France?"
    ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=40)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    single_prompt = [prompt]
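+    # Piecewise compilation exercises the torch.compile + CUDA graph path;
+    # eager runs disable compilation entirely.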
+    compilation_config = CompilationConfig(
+        level=CompilationLevel.PIECEWISE
+        if not enforce_eager else CompilationLevel.NO_COMPILATION,
+        cudagraph_share_memory_pool=False)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager)
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config)
        responses = llm.generate(single_prompt, sampling_params)
        ref_output = responses[0].outputs[0].text

@@ -229,7 +296,8 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
        m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager)
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config)
        responses = llm.generate(single_prompt, sampling_params)
        output = responses[0].outputs[0].text
        assert output == ref_output