From 4441fb72fea9962addf3953dc95ce4bc1ddeb63d Mon Sep 17 00:00:00 2001
From: Felix Yu
Date: Mon, 7 Jul 2025 19:36:31 +0000
Subject: [PATCH] Add reference to policy with state dict

---
 torchrl/collectors/collectors.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 0d76124b73f..bad334d8b32 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -689,6 +689,10 @@ def __init__(
                 policy = RandomPolicy(env.full_action_spec)
         elif policy_factory is not None:
             raise TypeError("policy_factory cannot be used with policy argument.")
+        # If the underlying policy has a state_dict, we keep a reference to the policy and
+        # do all policy weight saving/loading through it
+        if hasattr(policy, "state_dict"):
+            self._policy_w_state_dict = policy
 
         if trust_policy is None:
             trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
@@ -1681,8 +1685,8 @@ def state_dict(self) -> OrderedDict:
         else:
             env_state_dict = OrderedDict()
 
-        if hasattr(self.policy, "state_dict"):
-            policy_state_dict = self.policy.state_dict()
+        if hasattr(self, "_policy_w_state_dict"):
+            policy_state_dict = self._policy_w_state_dict.state_dict()
             state_dict = OrderedDict(
                 policy_state_dict=policy_state_dict,
                 env_state_dict=env_state_dict,
@@ -1706,7 +1710,13 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
         if strict or "env_state_dict" in state_dict:
             self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
         if strict or "policy_state_dict" in state_dict:
-            self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs)
+            if not hasattr(self, "_policy_w_state_dict"):
+                raise ValueError(
+                    "Underlying policy does not have state_dict to load policy_state_dict into."
+                )
+            self._policy_w_state_dict.load_state_dict(
+                state_dict["policy_state_dict"], **kwargs
+            )
         self._frames = state_dict["frames"]
         self._iter = state_dict["iter"]
 