I recently started re-implementing PILCO in PyTorch for better integration with my other work. To leverage the fast prediction (KISS-GP) in GPyTorch, I decided to use a Monte Carlo sampling approach to implement the core function in mgpr.py, predict_on_noisy_inputs, which uses moment matching in the original paper.
However, the result I get from sampling is dramatically different from moment matching, and I am wondering if anyone can help me identify the problem. The following code shows both optimize and predict_on_noisy_inputs.
def optimize(self, restarts=1, training_iter=200):
    self.likelihood.train()
    self.model.train()
    # Use the Adam optimizer
    optimizer = torch.optim.Adam([
        {'params': self.model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=self.lr)
    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = self.model(self.X)
        # Calc loss and backprop gradients
        loss = -mll(output, self.Y).sum()
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iter, loss.item()))
        optimizer.step()
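
For context, optimize() assumes a batch-independent multi-output ExactGP along the lines of the sketch below, so that self.model(self.X) returns a batched MultivariateNormal and -mll(output, self.Y).sum() adds up the per-output losses. This is only a simplified stand-in, not my exact model; the zero mean, RBF kernel, and toy data shapes here are placeholders.

import torch
import gpytorch

# Simplified stand-in for the batched model that optimize() trains:
# one independent GP per output dimension, sharing the same inputs.
class BatchedExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_outputs):
        # train_x: (num_outputs, N, D), train_y: (num_outputs, N)
        super().__init__(train_x, train_y, likelihood)
        batch_shape = torch.Size([num_outputs])
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=batch_shape),
            batch_shape=batch_shape,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Toy shapes only, for illustration.
num_outputs, N, D = 2, 50, 3
X = torch.randn(num_outputs, N, D)
Y = torch.randn(num_outputs, N)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([num_outputs]))
model = BatchedExactGP(X, Y, likelihood, num_outputs)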
def predict_on_noisy_inputs(self, m, s, num_samps=500):
    """
    Approximate GP regression at noisy inputs (moment matching in the original paper).
    IN:  mean (m) (row vector) and variance (s) of the state
    OUT: mean (M) (row vector), variance (S) of the action,
         and inv(s) * input-output covariance
    We adopt a sampling approach here to leverage the power of the GPU.
    """
    assert m.shape[1] == self.num_dims and s.shape == (self.num_dims, self.num_dims)
    self.likelihood.eval()
    self.model.eval()
    if self.cuda:
        m = torch.tensor(m).float().cuda()
        s = torch.tensor(s).float().cuda()
    inv_s = torch.inverse(s)
    sample_model = torch.distributions.MultivariateNormal(m, s)
    pred_inputs = sample_model.sample((num_samps,)).float()
    pred_inputs[pred_inputs != pred_inputs] = 0  # replace any NaNs in the samples with 0
    pred_inputs, _ = torch.sort(pred_inputs, dim=0)
    pred_inputs = pred_inputs.reshape(num_samps, self.num_dims).repeat(self.num_outputs, 1, 1)
    # centralize X ?
    # self.model.set_train_data(self.centralized_input(m), self.Y)
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        pred_outputs = self.model(pred_inputs)
    # Calculate mean, variance and inv(s) * input-output covariance
    M = torch.mean(pred_outputs.mean, 1)[None, :]
    V_ = torch.cat((pred_inputs[0].t(), pred_outputs.mean), 0)
    fact = 1.0 / (V_.size(1) - 1)
    V_ -= torch.mean(V_, dim=1, keepdim=True)
    V_t = V_.t()  # if complex: mt = m.t().conj()
    covs = fact * V_.matmul(V_t).squeeze()
    V = covs[0:self.num_dims, self.num_dims:]
    V = inv_s @ V
    S = covs[self.num_dims:, self.num_dims:]
    return M, S, V
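
As a sanity check, I have been comparing against a toy 1-D script along these lines (self-contained; ToyGP and all other names here are just for illustration and not part of mgpr.py). It estimates the predictive mean under a Gaussian input by Monte Carlo and by brute-force quadrature of the same integral, and estimates the predictive variance with the law of total variance, Var[E[y|x]] + E[Var[y|x]].

import torch
import gpytorch

class ToyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

train_x = torch.linspace(-3, 3, 50)
train_y = torch.sin(train_x) + 0.05 * torch.randn(50)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ToyGP(train_x, train_y, likelihood)
model.eval()
likelihood.eval()

m, s = 0.5, 0.04  # noisy test input: x ~ N(m, s)

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    # Monte Carlo: sample inputs and evaluate the predictive distribution at each sample.
    x_samp = m + s ** 0.5 * torch.randn(5000)
    pred = likelihood(model(x_samp))
    mc_mean = pred.mean.mean()
    # Law of total variance: Var[E[y|x]] + E[Var[y|x]]
    mc_var = pred.mean.var() + pred.variance.mean()

    # Brute-force quadrature of E[y] over x ~ N(m, s) as a reference for the mean.
    grid = m + s ** 0.5 * torch.linspace(-6.0, 6.0, 2001)
    w = torch.exp(-0.5 * (grid - m) ** 2 / s)
    w = w / w.sum()
    ref_mean = (likelihood(model(grid)).mean * w).sum()

print(mc_mean.item(), ref_mean.item(), mc_var.item())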