Commit 19f3bd5

[BugFix] Wrappers stack fn (#3061)
1 parent: be7156f

4 files changed: +10 −8 lines

test/llm/test_wrapper.py

Lines changed: 6 additions & 4 deletions
@@ -1192,6 +1192,8 @@ def test_retrieve_kl_input_modes(
             ref_model=ref_model,
             assistant_only=assistant_only,
             tokenizer=tokenizer,
+            gen_log_probs_full_key=("gen_log_probs", "full"),
+            ref_log_probs_full_key=("ref_log_probs", "full"),
         )
 
         # Apply transform
@@ -1202,15 +1204,15 @@ def test_retrieve_kl_input_modes(
         # Check that both log-probs and KL are present
         assert ("gen_log_probs", "full") in result
         assert ("ref_log_probs", "full") in result
-        assert "kl" in result
+        assert "kl_penalty" in result
 
         # Check KL structure
         if pad_output:
-            kl = result.get("kl")
+            kl = result.get("kl_penalty")
             assert isinstance(kl, torch.Tensor)
             assert kl.shape[0] == 2  # batch size
         else:
-            kl = result.get("kl", as_list=True)
+            kl = result.get("kl_penalty", as_list=True)
             # For unpadded output, we get a list of tensors
             assert isinstance(kl, list)
             assert len(kl) == 2  # batch size
@@ -1391,7 +1393,7 @@ def test_kl_computation_transform(
 
         # Check that reward is modified
         assert "reward" in result
-        reward = result.get("reward")
+        reward = result.get("reward", as_list=True)
         assert reward is not None
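The test now reads the renamed output key back in two shapes: a dense tensor when outputs are padded, a list when they are ragged. A minimal sketch of the unpadded access pattern, assuming a recent tensordict with `lazy_stack` and the `as_list` getter the test relies on (the tensors here are stand-ins, not the transform's real output):

```python
import torch
from tensordict import TensorDict, lazy_stack

# Stand-in for the transform's output: a lazily stacked batch whose
# "kl_penalty" entries have different lengths per sample (ragged).
result = lazy_stack(
    [
        TensorDict({"kl_penalty": torch.randn(5)}, batch_size=[]),
        TensorDict({"kl_penalty": torch.randn(7)}, batch_size=[]),
    ]
)

# Heterogeneous shapes cannot be fetched as one dense tensor, so the
# updated test reads them back as a list of tensors instead.
kl = result.get("kl_penalty", as_list=True)
assert isinstance(kl, list) and len(kl) == 2
```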
torchrl/envs/llm/transforms/kl.py

Lines changed: 2 additions & 2 deletions
@@ -680,8 +680,8 @@ class RetrieveKL(Compose):
             For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
             This ensures users are conscious of the limitation that assistant token identification requires structured conversation history.
 
-        gen_log_prob_full_key (str): the key where the log-probs of the generation model are stored. Defaults to `("log_probs", "full")`.
-        ref_log_prob_full_key (str): the key where the log-probs of the reference model are stored. Defaults to `("ref_log_probs", "full")`.
+        gen_log_probs_full_key (str): the key where the log-probs of the generation model are stored. Defaults to `("log_probs", "full")`.
+        ref_log_probs_full_key (str): the key where the log-probs of the reference model are stored. Defaults to `("ref_log_probs", "full")`.
         history_key (str): the key where the history is stored. Defaults to `"history"`.
         tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
             To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
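For reference, a construction sketch using the corrected keyword names; the argument values mirror the updated test, and `ref_model`/`tokenizer` are assumed to be built elsewhere (anything not shown keeps its default):

```python
from torchrl.envs.llm.transforms.kl import RetrieveKL

def build_retrieve_kl(ref_model, tokenizer):
    # ref_model: a reference policy wrapper; tokenizer: a HF tokenizer.
    # Remaining constructor arguments keep their defaults, as in the
    # (partially shown) test call.
    return RetrieveKL(
        ref_model=ref_model,
        assistant_only=True,
        tokenizer=tokenizer,
        gen_log_probs_full_key=("gen_log_probs", "full"),
        ref_log_probs_full_key=("ref_log_probs", "full"),
    )
```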

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -686,7 +686,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
         h_responses = _extract_responses_from_full_histories(
             text_full, prompt_histories, self.chat_template_name, self.tokenizer
         )
-        history_chat_flat.response = torch.stack(h_responses)
+        history_chat_flat.response = h_responses
         result.set(self.history_key, history_chat)
         return result
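This stack is the bug the commit title points at: responses extracted from full histories need not share a shape across the batch, and torch.stack refuses mismatched sizes. A minimal reproduction of the failure mode (plain tensors stand in for the History objects):

```python
import torch

# torch.stack requires every entry to have the same shape, which ragged
# per-sample responses do not.
resp_a, resp_b = torch.zeros(3), torch.zeros(5)
try:
    torch.stack([resp_a, resp_b])
except RuntimeError as exc:
    print(exc)  # stack expects each tensor to be equal size, ...
```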

torchrl/modules/llm/policies/vllm_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -726,7 +726,7 @@ def _from_vllm_generate_history(
         h_responses = _extract_responses_from_full_histories(
             text_full, prompt_histories, self.chat_template_name, self.tokenizer
         )
-        history_chat_flat.response = torch.stack(h_responses)
+        history_chat_flat.response = h_responses
         result.set(self.history_key, history_chat)
         return result
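The vLLM wrapper gets the identical one-line fix. Assigning the raw list defers any stacking to the surrounding History/tensordict machinery; as an analogy (not the wrappers' actual code path), tensordict's `lazy_stack` shows why deferring helps where a dense stack cannot:

```python
import torch
from tensordict import TensorDict, lazy_stack

# A lazy stack tolerates per-sample entries of different shapes, where
# the removed torch.stack could not.
stacked = lazy_stack(
    [
        TensorDict({"response": torch.zeros(3)}, batch_size=[]),
        TensorDict({"response": torch.zeros(5)}, batch_size=[]),
    ]
)
print(stacked.batch_size)  # torch.Size([2]); entries stay ragged underneath
```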
