
Commit 48111e9

Author: Vincent Moens
[Feature] Allow vLLMWrapper to generate multiple samples
ghstack-source-id: af58943 Pull Request resolved: #2878
1 parent afc5d59 commit 48111e9

File tree: 2 files changed, +86 −41 lines


test/test_actors.py

Lines changed: 12 additions & 0 deletions
@@ -1374,6 +1374,18 @@ def _run_check_collector(self, policy):
         assert "tokens" in data
         # assert ("next", "tokens") in data
 
+    def test_generate_multiple_trajs_vllm(self, vllm_instance):
+        policy = vLLMWrapper(
+            vllm_instance,
+            return_log_probs=True,
+            generate_kwargs={"n": 10, "max_tokens": 1024},
+            inplace=False,
+        )
+        data = TensorDict(
+            text=NonTensorStack("a string", "another very long string"), batch_size=2
+        )
+        data = policy(data)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
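For context, here is a sketch of how the new multi-sample path is meant to be exercised, following the test above. The model name, the import path of the wrapper and the final batch shape are assumptions for illustration, not guarantees of the API.

```python
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorStack
from vllm import LLM

from torchrl.modules.llm import vLLMWrapper  # import path assumed

model = LLM("facebook/opt-125m")  # any small vLLM model, chosen for illustration

policy = vLLMWrapper(
    model,
    return_log_probs=True,
    generate_kwargs={"n": 10, "max_tokens": 1024},  # 10 completions per prompt
    inplace=False,  # required (or left as None) when n > 1
)
data = TensorDict(
    text=NonTensorStack("a string", "another very long string"), batch_size=2
)
out = policy(data)
# The prompt entries are broadcast along a trailing sample dimension, so `out`
# is expected to have batch_size [2, 10] (assumption based on the expand logic below).
print(out.batch_size)
```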

torchrl/modules/llm/vllm_wrapper.py

Lines changed: 74 additions & 41 deletions
@@ -10,10 +10,10 @@
 import torch
 from tensordict import (
     lazy_stack,
-    LazyStackedTensorDict,
+    maybe_dense_stack,
     NestedKey,
     TensorDict,
-    TensorDictBase, maybe_dense_stack,
+    TensorDictBase,
 )
 from tensordict.tensorclass import from_dataclass, NonTensorStack, TensorClass
 from tensordict.utils import _zip_strict, expand_as_right
@@ -61,7 +61,8 @@ class vLLMWrapper(CategoricalSequential):
         inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
-            conserve type, batch-size, and device). Defaults to `True`.
+            conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
+            otherwise.
 
     .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
         required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
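A brief sketch of the three documented `inplace` modes, paraphrasing the docstring change above; the behaviours noted in the comments are drawn from this diff, not a verbatim API description, and the model name is a placeholder.

```python
from vllm import LLM
from torchrl.modules.llm import vLLMWrapper  # import path assumed

model = LLM("facebook/opt-125m")  # placeholder model

w_inplace = vLLMWrapper(model, inplace=True)     # write outputs back into the input TensorDict
w_new = vLLMWrapper(model, inplace=False)        # return a freshly built output TensorDict
w_empty = vLLMWrapper(model, inplace="empty")    # start from input.empty(): same type, batch-size, device

# After this commit, leaving inplace=None picks True for single-sample generation
# and False when generate_kwargs["n"] > 1 (see the resolution logic added below).
```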
@@ -125,7 +126,7 @@ def __init__(
         generate_kwargs: dict | None = None,
         tokenizer_kwargs: dict | None = None,
         pad_output: bool = False,
-        inplace: Literal[True, False, "empty"] | None = True,
+        inplace: Literal[True, False, "empty"] | None = None,
     ):
         super().__init__()
 
@@ -135,7 +136,6 @@ def __init__(
         self.from_text = from_text
         self._device = device
         self.generate = generate
-        self.inplace = inplace
         self.pad_output = pad_output
         padding_value = None
 
@@ -180,6 +180,18 @@ def __init__(
         else:
             generate_kwargs = dict(generate_kwargs)
 
+        if generate_kwargs.get("n", 1) > 1:
+            if inplace in (True, "empty"):
+                raise ValueError(
+                    "inplace must be False (or None) when generating more than one sample."
+                )
+            if inplace is None:
+                inplace = False
+        elif inplace is None:
+            inplace = True
+
+        self.inplace = inplace
+
         prompt_logprobs = False
 
@@ -225,45 +237,39 @@ def forward(
         if tensordict.device:
             tensordict = tensordict.copy().clear_device_()
 
-        out = LazyStackedTensorDict(
-            *[
-                TensorDict(
-                    device=tensordict.device, batch_size=tensordict.batch_size[1:]
-                )
-                for _ in range(tensordict.shape[0])
-            ]
-        )
         if self.from_text:
             if self.generate:
-                out = self._from_vllm_generate_text(tensordict, out=out)
+                out = self._from_vllm_generate_text(tensordict)
             else:
-                out = self._from_vllm_logprobs_text(tensordict, out=out)
+                out = self._from_vllm_logprobs_text(tensordict)
         else:
             if self.generate:
-                out = self._from_vllm_generate_tokens(tensordict, out=out)
+                out = self._from_vllm_generate_tokens(tensordict)
             else:
-                out = self._from_vllm_logprobs_tokens(tensordict, out=out)
+                out = self._from_vllm_logprobs_tokens(tensordict)
         if _source_device:
             out = out.to(_source_device)
 
         if tensordict_out is None:
             if self.inplace is True:
                 tensordict_out = tensordict
             elif self.inplace is False:
-                tensordict_out = TensorDict()
+                tensordict_out = out
             elif self.inplace == "empty":
                 tensordict_out = tensordict.empty()
 
-        if tensordict_out is not None:
+        if tensordict_out is not None and tensordict_out is not out:
             result = tensordict_out
             result.update(out, keys_to_update=self.out_keys)
-        else:
+        elif tensordict_out is not out:
             result = out
             keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
             return tensordict.update(result, keys_to_update=keys)
+        else:
+            result = out
         return result
 
-    def _from_vllm_generate_text(self, td, out):
+    def _from_vllm_generate_text(self, td):
         kwargs = {"sampling_params": self.sampling_params}
         args = ()
         input_ids = None
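For intuition on why in-place writing is ruled out with several samples: the helpers now broadcast the prompt-side entries to the output shape, which gains a trailing sample dimension and no longer matches the input batch-size. A minimal illustration with plain tensordict calls (the shapes are assumptions for the example):

```python
import torch
from tensordict import TensorDict

prompts = TensorDict({"tokens": torch.zeros(2, 5, dtype=torch.long)}, batch_size=[2])

# With generate_kwargs={"n": 10}, the helpers expand the inputs to the output shape:
expanded = prompts.unsqueeze(-1).expand(2, 10)
print(expanded.batch_size)  # torch.Size([2, 10])

# The result can no longer be written back into `prompts` (batch_size [2]),
# which is why forward() now returns `out` directly when inplace is False.
```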
@@ -301,16 +307,22 @@ def _from_vllm_generate_text(self, td, out):
             self.token_response_key,
             self.text_response_key,
             self.token_key,
+            self.attention_mask_key,
         ]
-        out.update(tokens_out, keys_to_update=in_keys)
+        out = tokens_out.select(*in_keys, strict=False)
         # We might already have the tokens
-        if input_ids is not None:
+        if input_ids is not None and self.token_key not in out:
             out[self.token_key] = input_ids
-        if attention_mask is not None:
+        if attention_mask is not None and self.attention_mask_key not in out:
             out[self.attention_mask_key] = attention_mask
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_logprobs_text(self, td, out):
+    def _from_vllm_logprobs_text(self, td):
         text_prompt = td.get(self.text_key)
         if not isinstance(text_prompt, list):
             text_prompt = text_prompt.tolist()
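The generate-from-text path above now relies on `select(..., strict=False)` to pull whatever prompt keys are present without raising on the missing ones (for instance an attention mask when the prompt came in as raw text). A small tensordict-only illustration:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"tokens": torch.zeros(2, 3, dtype=torch.long)}, batch_size=[2])

# strict=False keeps the keys that exist and silently skips the absent ones,
# so adding attention_mask_key to in_keys is safe for text-only prompts.
sub = td.select("tokens", "attention_mask", strict=False)
print(list(sub.keys()))  # ['tokens']
```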
@@ -358,7 +370,7 @@ def _from_vllm_logprobs_text(self, td, out):
         tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
         tokens_out = tokens_out.select(
             "prompt_token_ids", "prompt_logprobs", strict=False
-        )
+        )._tensordict
 
         # we disregard the tokens from the prompt to focus on those of the response
         if self.pad_output:
@@ -378,13 +390,19 @@ def _from_vllm_logprobs_text(self, td, out):
                 [lp[..., -len(tr) :] for lp, tr in zip(lps, input_ids_response)]
             )
 
+        out = tokens_out.empty(recurse=True)
         if isinstance(input_ids_response, list):
             input_ids_response = torch.nested.nested_tensor(input_ids_response)
         out["tokens_response"] = input_ids_response
         out["log_probs"] = lps
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_generate_tokens(self, td, out):
+    def _from_vllm_generate_tokens(self, td):
         input_ids = td.get(self.token_key)
         attention_mask = td.get(self.attention_mask_key)
         input_ids_list = self._to_list(input_ids, attention_mask)
@@ -414,12 +432,18 @@ def _from_vllm_generate_tokens(self, td, out):
             lps = tokens_response_td["log_probs"]
             lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
             tokens_response_td["log_probs"] = lps
+        out = tokens_response_td.empty(recurse=True)
         out.update(
             tokens_response_td, keys_to_update=(self.token_response_key, "log_probs")
         )
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
-    def _from_vllm_logprobs_tokens(self, td, out):
+    def _from_vllm_logprobs_tokens(self, td):
 
         tokens = td.get(self.token_key)
         tokens_response = td.get(self.token_response_key)
@@ -442,8 +466,14 @@ def _from_vllm_logprobs_tokens(self, td, out):
         prompt_logprobs = prompt_logprobs[..., -tokens_response.shape[-1] :]
         padded = tokens_response == self.padding_value
         prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0)
+        out = tokens_out._tensordict.empty(recurse=True)
         out.set("log_probs", prompt_logprobs)
         out.set(self.token_response_key, tokens_response)
+        inputs = td.select(*self.in_keys, strict=False)
+        if inputs.ndim < out.ndim:
+            # This happens when n > 1
+            inputs = inputs.unsqueeze(-1).expand(out.shape)
+        out.update(inputs)
         return out
 
     def _get_output_tokens_and_log_probs(self, tokens_out):
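Each helper above now allocates its own output container with `empty(recurse=True)` instead of receiving a pre-built lazy stack from `forward()`. That call returns a tensordict with the same batch-size, device and nesting but no leaves, which is then filled with the selected keys. A small illustration (shapes are arbitrary):

```python
import torch
from tensordict import TensorDict

src = TensorDict(
    {
        "tokens_response": torch.zeros(2, 10, 4, dtype=torch.long),
        "log_probs": torch.zeros(2, 10, 4),
    },
    batch_size=[2, 10],
)
out = src.empty(recurse=True)
print(out.batch_size)              # torch.Size([2, 10])
print(list(out.keys(True, True)))  # [] -- same shape, no leaves yet
out.set("log_probs", src["log_probs"])
```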
@@ -463,19 +493,21 @@ def _get_output_tokens_and_log_probs(self, tokens_out):
         if not self.pad_output:
             # Then we can safely move the input tokens, but otherwise they
             # may need padding
-            tokens_response_td.update(
-                tokens_out.select("prompt_token_ids")
-            ).rename_key_("prompt_token_ids", self.token_key)
+            tokens_out = tokens_out.select("prompt_token_ids")
+            if tokens_out.ndim < tokens_response_td.ndim:
+                tokens_out = tokens_out.unsqueeze(1).expand(tokens_response_td.shape)
+            tokens_response_td.update(tokens_out).rename_key_(
+                "prompt_token_ids", self.token_key
+            )
 
-        if self.return_log_probs:
+        if self.return_log_probs or "logprobs" in tokens_response_td:
             tokens_response_td.rename_key_("logprobs", "log_probs")
             if self.pad_output:
                 padded_values = tokens_response_td["tokens_response"] == padding_value
                 if padded_values.any():
                     lps = tokens_response_td["log_probs"]
                     lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
                     tokens_response_td["log_probs"] = lps
-
         return tokens_response_td
 
     def _to_list(self, tokens, attention_mask):
@@ -553,14 +585,15 @@ def get_logprob(output):
             self.outputs = outputs[0]
         else:
             self.outputs = maybe_dense_stack(outputs)
-        self.prompt_logprobs = torch.tensor(
-            [
-                v[tid].logprob if v is not None else 0.0
-                for v, tid in _zip_strict(
-                    self.prompt_logprobs, self.prompt_token_ids
-                )
-            ]
-        )
+        if self.prompt_logprobs is not None:
+            self.prompt_logprobs = torch.tensor(
+                [
+                    v[tid].logprob if v is not None else 0.0
+                    for v, tid in _zip_strict(
+                        self.prompt_logprobs, self.prompt_token_ids
+                    )
+                ]
+            )
         self.prompt_token_ids = torch.tensor(self.prompt_token_ids)
         self.num_cached_tokens = torch.tensor(self.num_cached_tokens)
 
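The last hunk guards the prompt log-probability conversion: during plain generation vLLM leaves `prompt_logprobs` as `None`, so the tensor conversion only makes sense when prompt log-probabilities were actually requested. A hedged standalone sketch of the same guard (the function name is illustrative, not the module's API):

```python
import torch


def prompt_logprobs_to_tensor(prompt_logprobs, prompt_token_ids):
    """Convert per-token prompt logprobs to a tensor, skipping the generation-only case."""
    if prompt_logprobs is None:
        # generation-only requests carry no prompt logprobs
        return None
    return torch.tensor(
        [
            lp[tid].logprob if lp is not None else 0.0
            for lp, tid in zip(prompt_logprobs, prompt_token_ids)
        ]
    )
```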