@@ -10,10 +10,10 @@
 import torch
 from tensordict import (
     lazy_stack,
-    LazyStackedTensorDict,
+    maybe_dense_stack,
     NestedKey,
     TensorDict,
-    TensorDictBase, maybe_dense_stack,
+    TensorDictBase,
 )
 from tensordict.tensorclass import from_dataclass, NonTensorStack, TensorClass
 from tensordict.utils import _zip_strict, expand_as_right
@@ -61,7 +61,8 @@ class vLLMWrapper(CategoricalSequential):
         inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
-            conserve type, batch-size, and device). Defaults to `True`.
+            conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
+            otherwise.
 
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
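
As a minimal sketch of the documented semantics (plain TensorDicts, no vLLM engine, illustrative key names): `inplace=True` writes into the input, `inplace="empty"` starts from `input.empty()`, and `inplace=False` returns a fresh structure.

    from tensordict import TensorDict
    import torch

    inp = TensorDict({"text": torch.zeros(3)}, batch_size=[3])
    new = TensorDict({"text_response": torch.ones(3)}, batch_size=[3])

    # inplace=True: the input tensordict itself receives the new entries.
    assert inp.update(new) is inp

    # inplace="empty": input.empty() keeps type, batch-size and device, but no keys.
    shell = inp.empty()
    assert shell.batch_size == inp.batch_size and "text" not in shell.keys()
    shell.update(new)
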
@@ -125,7 +126,7 @@ def __init__(
         generate_kwargs: dict | None = None,
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = False,
-        inplace: Literal[True, False, "empty"] | None = True,
+        inplace: Literal[True, False, "empty"] | None = None,
     ):
         super().__init__()
 
@@ -135,7 +136,6 @@ def __init__(
         self.from_text = from_text
         self._device = device
         self.generate = generate
-        self.inplace = inplace
         self.pad_output = pad_output
         padding_value = None
 
@@ -180,6 +180,18 @@ def __init__(
         else:
             generate_kwargs = dict(generate_kwargs)
 
+        if generate_kwargs.get("n", 1) > 1:
+            if inplace in (True, "empty"):
+                raise ValueError(
+                    "inplace must be False (or None) when generating more than one sample."
+                )
+            if inplace is None:
+                inplace = False
+        elif inplace is None:
+            inplace = True
+
+        self.inplace = inplace
+
         prompt_logprobs = False
 
         if not generate:
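
The default-resolution rule added above can be summarized by a hypothetical standalone helper (`resolve_inplace` is not part of the patch; it only restates the logic so it can be tested in isolation):

    def resolve_inplace(inplace, n: int):
        # n > 1 adds a sample dimension, so writing back into the input is unsafe.
        if n > 1:
            if inplace in (True, "empty"):
                raise ValueError(
                    "inplace must be False (or None) when generating more than one sample."
                )
            return False
        return True if inplace is None else inplace

    assert resolve_inplace(None, n=1) is True
    assert resolve_inplace(None, n=4) is False
    assert resolve_inplace("empty", n=1) == "empty"
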
@@ -225,45 +237,39 @@ def forward(
         if tensordict.device:
             tensordict = tensordict.copy().clear_device_()
 
-        out = LazyStackedTensorDict(
-            *[
-                TensorDict(
-                    device=tensordict.device, batch_size=tensordict.batch_size[1:]
-                )
-                for _ in range(tensordict.shape[0])
-            ]
-        )
         if self.from_text:
             if self.generate:
-                out = self._from_vllm_generate_text(tensordict, out=out)
+                out = self._from_vllm_generate_text(tensordict)
             else:
-                out = self._from_vllm_logprobs_text(tensordict, out=out)
+                out = self._from_vllm_logprobs_text(tensordict)
         else:
             if self.generate:
-                out = self._from_vllm_generate_tokens(tensordict, out=out)
+                out = self._from_vllm_generate_tokens(tensordict)
             else:
-                out = self._from_vllm_logprobs_tokens(tensordict, out=out)
+                out = self._from_vllm_logprobs_tokens(tensordict)
         if _source_device:
             out = out.to(_source_device)
 
         if tensordict_out is None:
             if self.inplace is True:
                 tensordict_out = tensordict
             elif self.inplace is False:
-                tensordict_out = TensorDict()
+                tensordict_out = out
             elif self.inplace == "empty":
                 tensordict_out = tensordict.empty()
 
-        if tensordict_out is not None:
+        if tensordict_out is not None and tensordict_out is not out:
             result = tensordict_out
             result.update(out, keys_to_update=self.out_keys)
-        else:
+        elif tensordict_out is not out:
             result = out
             keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
             return tensordict.update(result, keys_to_update=keys)
+        else:
+            result = out
         return result
 
-    def _from_vllm_generate_text(self, td, out):
+    def _from_vllm_generate_text(self, td):
         kwargs = {"sampling_params": self.sampling_params}
         args = ()
         input_ids = None
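
The rewritten `forward` leans on `update(..., keys_to_update=...)`, and the new `tensordict_out is not out` guards skip the no-op of copying `out` onto itself when `inplace=False`. A small sketch of the `keys_to_update` behavior (key names are made up):

    from tensordict import TensorDict
    import torch

    dst = TensorDict({"a": torch.zeros(2), "b": torch.zeros(2)}, batch_size=[2])
    src = TensorDict({"a": torch.ones(2), "c": torch.ones(2)}, batch_size=[2])

    # Only the listed keys flow from src into dst; "c" is ignored.
    dst.update(src, keys_to_update=["a"])
    assert (dst["a"] == 1).all() and "c" not in dst.keys()
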
@@ -301,16 +307,22 @@ def _from_vllm_generate_text(self, td, out):
             self.token_response_key,
             self.text_response_key,
             self.token_key,
+            self.attention_mask_key,
         ]
-        out.update(tokens_out, keys_to_update=in_keys)
+        out = tokens_out.select(*in_keys, strict=False)
         # We might already have the tokens
-        if input_ids is not None:
+        if input_ids is not None and self.token_key not in out:
             out[self.token_key] = input_ids
-        if attention_mask is not None:
+        if attention_mask is not None and self.attention_mask_key not in out:
             out[self.attention_mask_key] = attention_mask
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_logprobs_text(self, td, out):
+    def _from_vllm_logprobs_text(self, td):
         text_prompt = td.get(self.text_key)
         if not isinstance(text_prompt, list):
             text_prompt = text_prompt.tolist()
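
Each `_from_vllm_*` branch now re-attaches the inputs itself, expanding them when vLLM emits `n` samples per prompt. A sketch of that `unsqueeze(-1).expand(...)` broadcast with illustrative keys and shapes:

    from tensordict import TensorDict
    import torch

    B, n = 2, 3
    inputs = TensorDict({"text": torch.arange(B)}, batch_size=[B])
    out = TensorDict({"tokens_response": torch.zeros(B, n, 5)}, batch_size=[B, n])

    if inputs.ndim < out.ndim:
        # This happens when n > 1: batch [B] -> [B, 1] -> [B, n]
        inputs = inputs.unsqueeze(-1).expand(out.shape)
    out.update(inputs)
    assert out["text"].shape == (B, n)
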
@@ -358,7 +370,7 @@ def _from_vllm_logprobs_text(self, td, out):
         tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
         tokens_out = tokens_out.select(
             "prompt_token_ids", "prompt_logprobs", strict=False
-        )
+        )._tensordict
 
         # we disregard the tokens from the prompt to focus on those of the response
         if self.pad_output:
@@ -378,13 +390,19 @@ def _from_vllm_logprobs_text(self, td, out):
                 [lp[..., -len(tr):] for lp, tr in zip(lps, input_ids_response)]
             )
 
+        out = tokens_out.empty(recurse=True)
         if isinstance(input_ids_response, list):
             input_ids_response = torch.nested.nested_tensor(input_ids_response)
         out["tokens_response"] = input_ids_response
         out["log_probs"] = lps
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_generate_tokens(self, td, out):
+    def _from_vllm_generate_tokens(self, td):
         input_ids = td.get(self.token_key)
         attention_mask = td.get(self.attention_mask_key)
         input_ids_list = self._to_list(input_ids, attention_mask)
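
In the log-probs path only the trailing slice of each prompt's log-probs belongs to the response; a tensor-level sketch of the comprehension used above (ragged, illustrative shapes):

    import torch

    lps = [torch.randn(10), torch.randn(7)]         # prompt+response log-probs
    responses = [torch.arange(4), torch.arange(2)]  # response tokens

    resp_lps = [lp[..., -len(tr):] for lp, tr in zip(lps, responses)]
    assert [t.numel() for t in resp_lps] == [4, 2]
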
@@ -414,12 +432,18 @@ def _from_vllm_generate_tokens(self, td, out):
             lps = tokens_response_td["log_probs"]
             lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
             tokens_response_td["log_probs"] = lps
+        out = tokens_response_td.empty(recurse=True)
         out.update(
             tokens_response_td, keys_to_update=(self.token_response_key, "log_probs")
         )
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_logprobs_tokens(self, td, out):
+    def _from_vllm_logprobs_tokens(self, td):
 
         tokens = td.get(self.token_key)
         tokens_response = td.get(self.token_response_key)
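
Padded positions must not contribute log-probs. The sketch below mirrors the `torch.where` masking; in the real code, `expand_as_right` right-broadcasts the mask when the log-probs carry extra trailing dimensions:

    import torch

    padding_value = 0
    tokens_response = torch.tensor([[5, 7, 0], [3, 0, 0]])
    log_probs = torch.randn(2, 3)

    padded = tokens_response == padding_value
    log_probs = torch.where(~padded, log_probs, 0.0)
    assert (log_probs[padded] == 0).all()
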
@@ -442,8 +466,14 @@ def _from_vllm_logprobs_tokens(self, td, out):
         prompt_logprobs = prompt_logprobs[..., -tokens_response.shape[-1]:]
         padded = tokens_response == self.padding_value
         prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0)
+        out = tokens_out._tensordict.empty(recurse=True)
         out.set("log_probs", prompt_logprobs)
         out.set(self.token_response_key, tokens_response)
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
     def _get_output_tokens_and_log_probs(self, tokens_out):
@@ -463,19 +493,21 @@ def _get_output_tokens_and_log_probs(self, tokens_out):
         if not self.pad_output:
             # Then we can safely move the input tokens, but otherwise they
             # may need padding
-            tokens_response_td.update(
-                tokens_out.select("prompt_token_ids")
-            ).rename_key_("prompt_token_ids", self.token_key)
+            tokens_out = tokens_out.select("prompt_token_ids")
+            if tokens_out.ndim < tokens_response_td.ndim:
+                tokens_out = tokens_out.unsqueeze(1).expand(tokens_response_td.shape)
+            tokens_response_td.update(tokens_out).rename_key_(
+                "prompt_token_ids", self.token_key
+            )
 
-        if self.return_log_probs:
+        if self.return_log_probs or "logprobs" in tokens_response_td:
             tokens_response_td.rename_key_("logprobs", "log_probs")
         if self.pad_output:
             padded_values = tokens_response_td["tokens_response"] == padding_value
             if padded_values.any():
                 lps = tokens_response_td["log_probs"]
                 lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
                 tokens_response_td["log_probs"] = lps
-
         return tokens_response_td
 
     def _to_list(self, tokens, attention_mask):
@@ -553,14 +585,15 @@ def get_logprob(output):
             self.outputs = outputs[0]
         else:
             self.outputs = maybe_dense_stack(outputs)
-            self.prompt_logprobs = torch.tensor(
-                [
-                    v[tid].logprob if v is not None else 0.0
-                    for v, tid in _zip_strict(
-                        self.prompt_logprobs, self.prompt_token_ids
-                    )
-                ]
-            )
+            if self.prompt_logprobs is not None:
+                self.prompt_logprobs = torch.tensor(
+                    [
+                        v[tid].logprob if v is not None else 0.0
+                        for v, tid in _zip_strict(
+                            self.prompt_logprobs, self.prompt_token_ids
+                        )
+                    ]
+                )
         self.prompt_token_ids = torch.tensor(self.prompt_token_ids)
         self.num_cached_tokens = torch.tensor(self.num_cached_tokens)
 
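
The `None` guard matters because vLLM fills `prompt_logprobs` only when requested, and the first prompt position carries no log-prob. A sketch of the per-token lookup with stand-in objects (`SimpleNamespace` substitutes for vLLM's `Logprob`):

    from types import SimpleNamespace
    import torch

    prompt_token_ids = [11, 42]
    prompt_logprobs = [None, {42: SimpleNamespace(logprob=-0.7)}]

    vals = torch.tensor(
        [v[tid].logprob if v is not None else 0.0
         for v, tid in zip(prompt_logprobs, prompt_token_ids)]
    )
    assert torch.allclose(vals, torch.tensor([0.0, -0.7]))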