Skip to content

Commit 01f172a

Browse files
committed
Update
[ghstack-poisoned]
1 parent cb64cd5 commit 01f172a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchrl/modules/llm/vllm_policy.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TensorDictModuleBase,
2222
TensorDictSequential as Seq,
2323
)
24+
from tensordict.utils import _zip_strict
2425

2526
from torchrl.data import LLMData
2627

@@ -367,8 +368,8 @@ def get_logprob(output):
367368
self.prompt_logprobs = torch.tensor(
368369
[
369370
v[tid].logprob if v is not None else 0.0
370-
for v, tid in zip(
371-
self.prompt_logprobs, self.prompt_token_ids, strict=True
371+
for v, tid in _zip_strict(
372+
self.prompt_logprobs, self.prompt_token_ids
372373
)
373374
]
374375
)

0 commit comments

Comments
 (0)