File tree Expand file tree Collapse file tree 3 files changed +15
-5
lines changed Expand file tree Collapse file tree 3 files changed +15
-5
lines changed Original file line number Diff line number Diff line change 181
181
pytest.mark.filterwarnings(
182
182
"ignore:dep_util is Deprecated. Use functions from setuptools instead"
183
183
),
184
+ pytest.mark.filterwarnings(
185
+ "ignore:The PyTorch API of nested tensors is in prototype"
186
+ ),
184
187
]
185
188
186
189
@@ -16679,9 +16682,13 @@ def test_hf(self, from_text):
16679
16682
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
16680
16683
tokenizer.pad_token = tokenizer.eos_token
16681
16684
16682
- model = OPTForCausalLM(OPTConfig())
16685
+ model = OPTForCausalLM(OPTConfig()).eval()
16683
16686
policy_inference = TransformersWrapper(
16684
- model, tokenizer=tokenizer, generate=True, from_text=from_text
16687
+ model,
16688
+ tokenizer=tokenizer,
16689
+ generate=True,
16690
+ from_text=from_text,
16691
+ return_log_probs=True,
16685
16692
)
16686
16693
policy_train = TransformersWrapper(
16687
16694
model, tokenizer=tokenizer, generate=False, from_text=False
Original file line number Diff line number Diff line change @@ -606,8 +606,8 @@ def _step(self, tensordict):
606
606
reward = torch .tensor ([reward_val ], dtype = torch .float32 )
607
607
dest .set ("reward" , reward )
608
608
dest .set ("turn" , turn )
609
- dest .set ("done" , [done ])
610
- dest .set ("terminated" , [done ])
609
+ dest .set ("done" , torch . tensor ( [done ]) )
610
+ dest .set ("terminated" , torch . tensor ( [done ]) )
611
611
if self .pixels :
612
612
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
613
613
return dest
Original file line number Diff line number Diff line change @@ -584,7 +584,10 @@ def _log_weight(
584
584
self .tensor_keys .sample_log_prob ,
585
585
adv_shape ,
586
586
)
587
-
587
+ if prev_log_prob is None :
588
+ raise KeyError (
589
+ f"Couldn't find the log-prob { self .tensor_keys .sample_log_prob } in the input data."
590
+ )
588
591
if prev_log_prob .requires_grad :
589
592
raise RuntimeError (
590
593
f"tensordict stored { self .tensor_keys .sample_log_prob } requires grad."
You can’t perform that action at this time.
0 commit comments