Skip to content

Commit 6f634c6

Browse files
author
Vincent Moens
committed
[Test] Fix warnings in tests
ghstack-source-id: d4ed75d
Pull Request resolved: #2886
1 parent b66fcd4 commit 6f634c6

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

test/test_cost.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,9 @@
181181
pytest.mark.filterwarnings(
182182
"ignore:dep_util is Deprecated. Use functions from setuptools instead"
183183
),
184+
pytest.mark.filterwarnings(
185+
"ignore:The PyTorch API of nested tensors is in prototype"
186+
),
184187
]
185188

186189

@@ -16679,9 +16682,13 @@ def test_hf(self, from_text):
1667916682
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
1668016683
tokenizer.pad_token = tokenizer.eos_token
1668116684

16682-
model = OPTForCausalLM(OPTConfig())
16685+
model = OPTForCausalLM(OPTConfig()).eval()
1668316686
policy_inference = TransformersWrapper(
16684-
model, tokenizer=tokenizer, generate=True, from_text=from_text
16687+
model,
16688+
tokenizer=tokenizer,
16689+
generate=True,
16690+
from_text=from_text,
16691+
return_log_probs=True,
1668516692
)
1668616693
policy_train = TransformersWrapper(
1668716694
model, tokenizer=tokenizer, generate=False, from_text=False

torchrl/envs/custom/chess.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -606,8 +606,8 @@ def _step(self, tensordict):
606606
reward = torch.tensor([reward_val], dtype=torch.float32)
607607
dest.set("reward", reward)
608608
dest.set("turn", turn)
609-
dest.set("done", [done])
610-
dest.set("terminated", [done])
609+
dest.set("done", torch.tensor([done]))
610+
dest.set("terminated", torch.tensor([done]))
611611
if self.pixels:
612612
dest.set("pixels", self._get_tensor_image(board=self.board))
613613
return dest

torchrl/objectives/ppo.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -584,7 +584,10 @@ def _log_weight(
584584
self.tensor_keys.sample_log_prob,
585585
adv_shape,
586586
)
587-
587+
if prev_log_prob is None:
588+
raise KeyError(
589+
f"Couldn't find the log-prob {self.tensor_keys.sample_log_prob} in the input data."
590+
)
588591
if prev_log_prob.requires_grad:
589592
raise RuntimeError(
590593
f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."

0 commit comments

Comments
 (0)