Hello!
Thanks, first of all, for the library. It has been of great help to me!
Now, I wanted to discuss a portion of the code that I believe to be erroneous. In the class `RolloutGenerator`, the function `get_action_and_transition()` contains the following code:
```python
def get_action_and_transition(self, obs, hidden):
""" Get action and transition.
Encode obs to latent using the VAE, then obtain estimation for next
latent and next hidden state using the MDRNN and compute the controller
corresponding action.
:args obs: current observation (1 x 3 x 64 x 64) torch tensor
:args hidden: current hidden state (1 x 256) torch tensor
:returns: (action, next_hidden)
- action: 1D np array
- next_hidden (1 x 256) torch tensor
"""
_, latent_mu, _ = self.vae(obs)
action = self.controller(latent_mu, hidden[0])
_, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
return action.squeeze().cpu().numpy(), next_hidden
```
I think the docstring describes the intent clearly. The problem is that the code feeds `latent_mu` to both the controller and the MDRNN. I would argue that we should use the actual sampled latent vector instead (let's call it `z`).
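Concretely, here is a minimal sketch of the change I have in mind. It assumes the VAE's forward pass returns `(reconstruction, mu, logsigma)`, which is what the unpacking `_, latent_mu, _ = self.vae(obs)` suggests; if the third output is a log-variance rather than a log-standard-deviation, the scaling term would need to be `(0.5 * logvar).exp()` instead:

```python
# Hypothetical revision of get_action_and_transition(): sample z from the
# encoder's Gaussian posterior instead of using its mean directly.
_, latent_mu, latent_logsigma = self.vae(obs)  # assumes third output is log(sigma)
# Reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I)
z = latent_mu + latent_logsigma.exp() * torch.randn_like(latent_mu)
action = self.controller(z, hidden[0])
_, _, _, _, _, next_hidden = self.mdrnn(action, z, hidden)
return action.squeeze().cpu().numpy(), next_hidden
```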
First, the current implementation is not what they do in the original World Models paper, where the controller is described as, and I quote:

> C is a simple single layer linear model that maps z_t and h_t directly to action a_t at each time step.
Second, we train the MDRNN using sampled latent vectors `z` (see the function `to_latent()` in `trainmdrnn.py`). Why, then, do we switch to `latent_mu` at rollout time?
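For reference, the sampling pattern in `to_latent()` looks roughly like this (a paraphrase, not the exact code; the variable names and reshaping are simplified here):

```python
# Rough paraphrase of the sampling in trainmdrnn.py's to_latent():
# the MDRNN is trained on z sampled from the encoder's posterior,
# not on the posterior mean latent_mu.
mu, logsigma = obs_mu, obs_logsigma        # per-frame Gaussian parameters from the VAE encoder
z = mu + logsigma.exp() * torch.randn_like(mu)
```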
This problem affects both the training and the testing of the controller. It might even be the reason why you report that the memory module is of little to no help in your experiments (https://ctallec.github.io/world-models/). That said, I must say I haven't done any proper testing yet.
I would like to hear your thoughts on this.