 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import gc
+import random
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import pytest
 import torch
@@ -105,7 +106,7 @@ def forward(


 @support_torch_compile
-class DecoderLayerGroup(nn.Module):
+class FirstLayerGroup(nn.Module):

     def __init__(
         self,
@@ -121,7 +122,35 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+    ):
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        return hidden_states, residual
+
+
+@support_torch_compile
+class SecondLayerGroup(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layers: list[nn.Module],
+    ):
+        super().__init__()
+        self.layers = layers
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
     ):
         for layer in self.layers:
             hidden_states, residual = layer(
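
Annotation (not part of the diff): the old DecoderLayerGroup is split into two separately compiled modules, presumably so each @support_torch_compile graph sees a fixed input signature: FirstLayerGroup always starts the residual stream from None, while SecondLayerGroup always receives a concrete residual tensor. The toy sketch below only illustrates that calling contract; toy_layer, first_group and second_group are made-up names, not vLLM code.

import torch

def toy_layer(positions, hidden_states, residual):
    # Mimics a decoder layer's (hidden_states, residual) return contract.
    residual = hidden_states if residual is None else residual + hidden_states
    return hidden_states * 2, residual

def first_group(positions, hidden_states):
    residual = None  # the first group starts the residual stream
    for _ in range(2):
        hidden_states, residual = toy_layer(positions, hidden_states, residual)
    return hidden_states, residual

def second_group(positions, hidden_states, residual):
    # the second group continues an existing residual stream
    for _ in range(2):
        hidden_states, residual = toy_layer(positions, hidden_states, residual)
    return hidden_states, residual

positions = torch.arange(4)
hidden_states = torch.ones(4, 8)
hidden_states, residual = first_group(positions, hidden_states)
hidden_states, residual = second_group(positions, hidden_states, residual)
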
@@ -147,15 +176,17 @@ def __init__(self,
             decoder_layer_type=decoder_layer_type,
         )

+        self.vllm_config = vllm_config
+
         with set_model_tag("first_layer_group"):
-            self.first_layer_group = DecoderLayerGroup(
+            self.first_layer_group = FirstLayerGroup(
                 vllm_config=vllm_config,
                 prefix=f"{prefix}.first_layer_group",
                 layers=self.layers[self.start_layer:START_KV_SHARING_LAYER],
             )

         with set_model_tag("second_layer_group"):
-            self.second_layer_group = DecoderLayerGroup(
+            self.second_layer_group = SecondLayerGroup(
                 vllm_config=vllm_config,
                 prefix=f"{prefix}.second_layer_group",
                 layers=self.layers[START_KV_SHARING_LAYER:self.end_layer],
@@ -170,6 +201,10 @@ def __init__(self,
         self.residual = torch.zeros((self.max_num_tokens, self.hidden_size),
                                     dtype=self.dtype,
                                     device=self.device)
+        self.hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)

     def forward(
         self,
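
Annotation (not part of the diff): self.hidden_states joins the existing self.residual as a pre-allocated, fixed-address buffer, since CUDA graph replay requires the captured region to read from stable tensor addresses; the forward pass then copies into slices of these buffers instead of handing fresh tensors to the second layer group. A minimal sketch of that copy-into-slice pattern, with made-up sizes and names:

import torch

max_num_tokens, hidden_size = 8, 4
buffer = torch.zeros(max_num_tokens, hidden_size)  # allocated once; address never changes

def stage(fresh_hidden_states):
    n = fresh_hidden_states.size(0)
    buffer[:n].copy_(fresh_hidden_states)  # copy new data into the persistent buffer
    return buffer[:n]  # captured/compiled code only ever sees this stable slice

out = stage(torch.randn(5, hidden_size))
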
@@ -183,11 +218,12 @@ def forward(
         else:
             hidden_states = self.get_input_embeddings(input_ids)

-        residual = None
+        num_input_tokens = input_ids.size(0)
+        self.hidden_states[:num_input_tokens].copy_(hidden_states)
+
         first_hidden_states, first_residual = self.first_layer_group(
             positions,
-            hidden_states,
-            residual,  # no residual, assume no pipeline parallel
+            self.hidden_states[:num_input_tokens],
         )

         decode_indices = get_forward_context().decode_indices
@@ -202,15 +238,24 @@ def forward(
         # CUDA graph expects static tensor addresses
         # Copy output of first layer group to second layer group
         self.residual[:num_decodes].copy_(first_residual[decode_indices])
-        hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices])
+        self.hidden_states[:num_decodes].copy_(
+            first_hidden_states[decode_indices])
         positions[:num_decodes].copy_(positions[decode_indices])

         second_hidden_states, second_residual = self.second_layer_group(
             positions[:num_decodes],
-            hidden_states[:num_decodes],
+            self.hidden_states[:num_decodes],
             self.residual[:num_decodes],
         )

+        # NOTE(sarckk): Due to cudagraph padding, decode_indices may have
+        # trailing repeated indices. Attention output is only valid at the
+        # last index in this case.
+        last_index_mask = decode_indices == decode_indices[-1]
+        second_hidden_states[last_index_mask] = second_hidden_states[-1].clone(
+        )
+        second_residual[last_index_mask] = second_residual[-1].clone()
+
         # Merge results back
         first_hidden_states[decode_indices] = second_hidden_states
         if first_residual is not None:
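
Annotation (not part of the diff): a worked example of the padding fix above, using made-up values. Suppose three real decode tokens are padded to five slots for CUDA graph capture, so decode_indices repeats its last real index; only the final slot then holds valid attention output, and the mask copies it over the padded slots:

import torch

decode_indices = torch.tensor([1, 4, 7, 7, 7])  # last real index 7 repeated as padding
second_hidden_states = torch.arange(5, dtype=torch.float32).unsqueeze(1)  # one row per slot

last_index_mask = decode_indices == decode_indices[-1]  # [False, False, True, True, True]
second_hidden_states[last_index_mask] = second_hidden_states[-1].clone()
# rows 2-4 now all hold the value computed at the last (valid) slot
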
@@ -270,16 +315,43 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)


+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    # Setting higher num prompts increases the chance of numerics mismatch
+    # due to matrix multiplication numerics depending on batch dimension
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""please repeat the word '{word}' 10 times."""
+        elif kind == "sentence":
+            prompt = f"""please give a ten-word sentence that
+            uses the word {word} at least once."""
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append(prompt)
+
+    return prompts
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_skip_prefill(
     monkeypatch: pytest.MonkeyPatch,
     enforce_eager: bool,
-    test_prompts: list[list[dict[str, Any]]],
+    test_prompts: list[str],
 ):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-    prompts = [prompt[0]['content'] for prompt in test_prompts]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
         if not enforce_eager else CompilationLevel.NO_COMPILATION,
@@ -293,8 +365,7 @@ def test_kv_sharing_skip_prefill(
         enforce_eager=enforce_eager,
         compilation_config=compilation_config,
     )
-    responses = llm.generate(prompts, sampling_params)
-    ref_output = responses[0].outputs[0].text
+    ref_responses = llm.generate(test_prompts, sampling_params)

     del llm
     gc.collect()
@@ -304,6 +375,14 @@ def test_kv_sharing_skip_prefill(
         enforce_eager=enforce_eager,
         compilation_config=compilation_config,
         kv_sharing_skip_prefill=True)
-    responses = llm.generate(prompts, sampling_params)
-    output = responses[0].outputs[0].text
-    assert output == ref_output
+    optimized_responses = llm.generate(test_prompts, sampling_params)
+
+    misses = 0
+
+    for ref_response, optimized_response in zip(ref_responses,
+                                                optimized_responses):
+        if ref_response.outputs[0].text != optimized_response.outputs[
+                0].text:
+            misses += 1
+
+    assert misses == 0