Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions learning/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,16 @@ def save_checkpoint(self):
def load_checkpoint(self):
"""
Loads the checkpoint from the specified file.

Returns:
bool: True if the checkpoint was successfully loaded, False otherwise.
"""
self.load_state_dict(T.load(self.checkpoint_file))
if os.path.exists(self.checkpoint_file):
self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
return True
else:
print(f"Checkpoint not found: {self.checkpoint_file}")
return False


class ActorNetwork(nn.Module):
Expand Down Expand Up @@ -192,8 +200,16 @@ def save_checkpoint(self):
def load_checkpoint(self):
"""
Loads the checkpoint from the specified file.

Returns:
bool: True if the checkpoint was successfully loaded, False otherwise.
"""
self.load_state_dict(T.load(self.checkpoint_file))
if os.path.exists(self.checkpoint_file):
self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
return True
else:
print(f"Checkpoint not found: {self.checkpoint_file}")
return False


class Agent():
Expand Down Expand Up @@ -456,13 +472,16 @@ def load_model(self):
It attempts to load the weights for the actor, critics, and target networks.
If any of the loading operations fail, it prints an error message and continues training from scratch.
"""
try:
self.actor.load_checkpoint()
self.critic_1.load_checkpoint()
self.critic_2.load_checkpoint()
self.target_actor.load_checkpoint()
self.target_critic_1.load_checkpoint()
self.target_critic_2.load_checkpoint()
print("Sucessfully loaded all the models")
except:
print("Failed to laod models. Starting from Scratch ")
successes = [
self.actor.load_checkpoint(),
self.critic_1.load_checkpoint(),
self.critic_2.load_checkpoint(),
self.target_actor.load_checkpoint(),
self.target_critic_1.load_checkpoint(),
self.target_critic_2.load_checkpoint(),
]

if all(successes):
print("Successfully loaded all the models")
else:
print("Failed to load models. Starting from scratch")