Skip to content

Commit 165163a

Browse files
author
Vincent Moens
committed
[BugFix] Fix tests failing because of pytorch/pytorch#137602
cc mikaylagawarecki albanD ghstack-source-id: 6fc7434 Pull Request resolved: #2558
1 parent 50a35f6 commit 165163a

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

test/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def make_storage():
450450
rb_trainer2.register(trainer2)
451451
if re_init:
452452
trainer2._process_batch_hook(td.to_tensordict().zero_())
453-
trainer2.load_from_file(file)
453+
trainer2.load_from_file(file, weights_only=False)
454454
assert state_dict_has_been_called[0]
455455
assert load_state_dict_has_been_called[0]
456456
assert state_dict_has_been_called_td[0]

torchrl/trainers/trainers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,17 @@ def save_trainer(self, force_save: bool = False) -> None:
296296
if _save and self.save_trainer_file:
297297
self._save_trainer()
298298

299-
def load_from_file(self, file: Union[str, pathlib.Path]) -> Trainer:
299+
def load_from_file(self, file: Union[str, pathlib.Path], **kwargs) -> Trainer:
300+
"""Loads a file and its state-dict in the trainer.
301+
302+
Keyword arguments are passed to the :func:`~torch.load` function.
303+
304+
"""
300305
if _CKPT_BACKEND == "torchsnapshot":
301306
snapshot = Snapshot(path=file)
302307
snapshot.restore(app_state=self.app_state)
303308
elif _CKPT_BACKEND == "torch":
304-
loaded_dict: OrderedDict = torch.load(file)
309+
loaded_dict: OrderedDict = torch.load(file, **kwargs)
305310
self.load_state_dict(loaded_dict)
306311
return self
307312

0 commit comments

Comments
 (0)