Skip to content

Commit c9caf3d

Browse files
Authored and committed by Vincent Moens
[Feature] Support lazy tensordict inputs in ppo loss
ghstack-source-id: 89098ba
Pull Request resolved: #2883
1 parent 3e1f4ff commit c9caf3d

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

torchrl/modules/llm/vllm_wrapper.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __init__(
223223
if from_text:
224224
self.out_keys += [self.text_response_key, self.token_key]
225225
if self.return_log_probs:
226-
self.out_keys += ["log_probs"]
226+
self.out_keys += [self.log_prob_key]
227227

228228
def forward(
229229
self,
@@ -303,7 +303,7 @@ def _from_vllm_generate_text(self, td):
303303
),
304304
)
305305
in_keys = [
306-
"log_probs",
306+
self.log_prob_key,
307307
self.token_response_key,
308308
self.text_response_key,
309309
self.token_key,
@@ -394,7 +394,7 @@ def _from_vllm_logprobs_text(self, td):
394394
if isinstance(input_ids_response, list):
395395
input_ids_response = torch.nested.nested_tensor(input_ids_response)
396396
out["tokens_response"] = input_ids_response
397-
out["log_probs"] = lps
397+
out[self.log_prob_key] = lps
398398
inputs = td.select(*self.in_keys, strict=False)
399399
if inputs.ndim < out.ndim:
400400
# This happens when n > 1
@@ -423,18 +423,19 @@ def _from_vllm_generate_tokens(self, td):
423423
).to_padded_tensor(padding=self.padding_value)
424424
tokens_response_td.rename_key_("token_ids", "tokens_response")
425425
if self.return_log_probs:
426-
tokens_response_td.rename_key_("logprobs", "log_probs")
426+
tokens_response_td.rename_key_("logprobs", self.log_prob_key)
427427
if self.pad_output:
428428
padded_values = (
429429
tokens_response_td["tokens_response"] == self.padding_value
430430
)
431431
if padded_values.any():
432-
lps = tokens_response_td["log_probs"]
432+
lps = tokens_response_td[self.log_prob_key]
433433
lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
434-
tokens_response_td["log_probs"] = lps
434+
tokens_response_td[self.log_prob_key] = lps
435435
out = tokens_response_td.empty(recurse=True)
436436
out.update(
437-
tokens_response_td, keys_to_update=(self.token_response_key, "log_probs")
437+
tokens_response_td,
438+
keys_to_update=(self.token_response_key, self.log_prob_key),
438439
)
439440
inputs = td.select(*self.in_keys, strict=False)
440441
if inputs.ndim < out.ndim:
@@ -467,7 +468,7 @@ def _from_vllm_logprobs_tokens(self, td):
467468
padded = tokens_response == self.padding_value
468469
prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0)
469470
out = tokens_out._tensordict.empty(recurse=True)
470-
out.set("log_probs", prompt_logprobs)
471+
out.set(self.log_prob_key, prompt_logprobs)
471472
out.set(self.token_response_key, tokens_response)
472473
inputs = td.select(*self.in_keys, strict=False)
473474
if inputs.ndim < out.ndim:
@@ -501,13 +502,13 @@ def _get_output_tokens_and_log_probs(self, tokens_out):
501502
)
502503

503504
if self.return_log_probs or "logprobs" in tokens_response_td:
504-
tokens_response_td.rename_key_("logprobs", "log_probs")
505+
tokens_response_td.rename_key_("logprobs", self.log_prob_key)
505506
if self.pad_output:
506507
padded_values = tokens_response_td["tokens_response"] == padding_value
507508
if padded_values.any():
508-
lps = tokens_response_td["log_probs"]
509+
lps = tokens_response_td[self.log_prob_key]
509510
lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
510-
tokens_response_td["log_probs"] = lps
511+
tokens_response_td[self.log_prob_key] = lps
511512
return tokens_response_td
512513

513514
def _to_list(self, tokens, attention_mask):

torchrl/objectives/ppo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def _get_cur_log_prob(self, tensordict):
533533
if isinstance(
534534
self.actor_network,
535535
(ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
536-
):
536+
) or hasattr(self.actor_network, "get_dist"):
537537
# assert tensordict['log_probs'].requires_grad
538538
# assert tensordict['logits'].requires_grad
539539
with self.actor_network_params.to_module(
@@ -987,7 +987,9 @@ def out_keys(self, values):
987987
@dispatch
988988
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
989989
tensordict = tensordict.clone(False)
990-
advantage = tensordict.get(self.tensor_keys.advantage, None)
990+
advantage = tensordict.get(
991+
self.tensor_keys.advantage, None, as_padded_tensor=True
992+
)
991993
if advantage is None:
992994
if self.critic_network is None:
993995
raise RuntimeError(

torchrl/objectives/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
624624

625625
def _maybe_get_or_select(td, key_or_keys, target_shape=None):
626626
if isinstance(key_or_keys, (str, tuple)):
627-
return td.get(key_or_keys)
627+
return td.get(key_or_keys, as_padded_tensor=True)
628628
result = td.select(*key_or_keys)
629629
if target_shape is not None and result.shape != target_shape:
630630
result.batch_size = target_shape

0 commit comments

Comments (0)