
Commit 280776d

1/25 Updates
1 parent 4f35e0f commit 280776d

File tree

1 file changed: +55 -77 lines changed


botorch_community/models/np_regression.py

Lines changed: 55 additions & 77 deletions
@@ -11,27 +11,16 @@
 Contributor: eibarolle
 """
 
-import copy
-import numpy as np
-from numpy.random import binomial
 import torch
 import torch.nn as nn
-import matplotlib.pyplot as plts
-# %matplotlib inline
 from botorch.models.model import Model
 from botorch.posteriors import GPyTorchPosterior
 from botorch.acquisition.objective import PosteriorTransform
-from sklearn.gaussian_process import GaussianProcessRegressor
-from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
-                                              ExpSineSquared, DotProduct,
-                                              ConstantKernel)
 from typing import Callable, List, Optional, Tuple
-from torch.nn import Module, ModuleDict, ModuleList
-from sklearn import preprocessing
-from scipy.stats import multivariate_normal
+from torch.nn import Module
 from gpytorch.distributions import MultivariateNormal
 
-device = torch.device("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # Account for different acquisitions
 
 #reference: https://chrisorm.github.io/NGP.html
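The device change above follows the standard PyTorch pattern: select CUDA when it is available, fall back to CPU otherwise, and keep modules and tensors on the same device. A minimal sketch of that pattern (the layer and tensor shapes are illustrative, not taken from the file):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

layer = nn.Linear(4, 8).to(device)   # module parameters now live on `device`
x = torch.randn(16, 4).to(device)    # inputs must be moved to the same device
out = layer(x)                       # mixing devices would raise a runtime error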
@@ -59,21 +48,21 @@ def __init__(
         prev_dim = input_dim
 
         for hidden_dim in hidden_dims:
-            layer = nn.Linear(prev_dim, hidden_dim)
+            layer = nn.Linear(prev_dim, hidden_dim).to(device)
             if init_func is not None:
                 init_func(layer.weight)
             layers.append(layer)
             layers.append(activation())
             prev_dim = hidden_dim
 
-        final_layer = nn.Linear(prev_dim, output_dim)
+        final_layer = nn.Linear(prev_dim, output_dim).to(device)
         if init_func is not None:
             init_func(final_layer.weight)
         layers.append(final_layer)
-        self.model = nn.Sequential(*layers)
+        self.model = nn.Sequential(*layers).to(device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+        return self.model(x.to(device))
 
 
 class REncoder(nn.Module):
@@ -95,12 +84,9 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)
 
-    def forward(
-        self,
-        inputs: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for representation encoder.
 
         Args:
@@ -109,7 +95,7 @@ def forward(
         Returns:
             torch.Tensor: Encoded representations
         """
-        return self.mlp(inputs)
+        return self.mlp(inputs.to(device))
 
 class ZEncoder(nn.Module):
     def __init__(self,
@@ -130,13 +116,10 @@ def __init__(self,
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mean_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
-        self.logvar_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mean_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.logvar_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)
 
-    def forward(
-        self,
-        inputs: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for latent encoder.
 
         Args:
@@ -147,6 +130,7 @@ def forward(
                 - Mean of the latent Gaussian distribution.
                 - Log variance of the latent Gaussian distribution.
         """
+        inputs = inputs.to(device)
         return self.mean_net(inputs), self.logvar_net(inputs)
 
 class Decoder(torch.nn.Module):
@@ -168,23 +152,21 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)
 
-    def forward(
-        self,
-        x_pred: torch.Tensor,
-        z: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for decoder.
 
         Args:
-            x_pred: No. of data points, by x_dim
-            z: No. of samples, by z_dim
+            x_pred: Input points of shape (n x d_x), representing # of data points by x_dim.
+            z: Latent encoding of shape (num_samples x d_z), representing # of samples by z_dim.
 
         Returns:
-            torch.Tensor: Predicted target values.
+            torch.Tensor: Predicted target values of shape (n, z_dim), representing # of data points by z_dim.
         """
-        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1)
+        z = z.to(device)
+        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device)
+        x_pred = x_pred.to(device)
         xz = torch.cat([x_pred, z_expanded], dim=-1)
         return self.mlp(xz)

@@ -231,16 +213,14 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
        """
         super().__init__()
-        self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func)
-        self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func)
-        self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func)
+        self.r_encoder = REncoder(x_dim+y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func).to(device)
         self.z_dim = z_dim
         self.z_mu_all = None
         self.z_logvar_all = None
         self.z_mu_context = None
         self.z_logvar_context = None
-        # self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) # Look at BoTorch native versions
-        #self.train(n_epochs, x_train, y_train)
 
     def data_to_z_params(
         self,
@@ -264,18 +244,20 @@ def data_to_z_params(
                 - x_t: Target input data.
                 - y_t: Target target data.
         """
-        xy = torch.cat([x,y], dim=xy_dim)
+        x = x.to(device)
+        y = y.to(device)
+        xy = torch.cat([x,y], dim=xy_dim).to(device)
         rs = self.r_encoder(xy)
-        r_agg = rs.mean(dim=r_dim)
+        r_agg = rs.mean(dim=r_dim).to(device)
         return self.z_encoder(r_agg)
 
     def sample_z(
         self,
         mu: torch.Tensor,
         logvar: torch.Tensor,
         n: int = 1,
-        min_std: float = 0.1,
-        scaler: float = 0.9
+        min_std: float = 0.01,
+        scaler: float = 0.5
     ) -> torch.Tensor:
         r"""Reparameterization trick for z's latent distribution.
 
@@ -291,12 +273,15 @@ def sample_z(
         """
         if min_std <= 0 or scaler <= 0:
             raise ValueError()
+
+        shape = [n, self.z_dim]
         if n == 1:
-            eps = torch.autograd.Variable(logvar.data.new(self.z_dim).normal_()).to(device)
-        else:
-            eps = torch.autograd.Variable(logvar.data.new(n,self.z_dim).normal_()).to(device)
+            shape = shape[1:]
+        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device)
 
         std = min_std + scaler * torch.sigmoid(logvar)
+        std = std.to(device)
+        mu = mu.to(device)
         return mu + std * eps
 
     def KLD_gaussian(
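The rewritten sampling above replaces the two eps branches with a single shape computation before drawing noise. A standalone sketch of the same reparameterization step, assuming illustrative mu/logvar tensors and a helper name (reparameterize) that does not exist in the file:

import torch

def reparameterize(mu, logvar, n=1, min_std=0.01, scaler=0.5):
    # eps has shape [z_dim] when n == 1, otherwise [n, z_dim]
    shape = [n, mu.size(-1)]
    if n == 1:
        shape = shape[1:]
    eps = torch.randn(*shape, device=mu.device)
    # bounded positive std, mirroring the diff: min_std + scaler * sigmoid(logvar)
    std = min_std + scaler * torch.sigmoid(logvar)
    return mu + std * eps

mu, logvar = torch.zeros(3), torch.zeros(3)
print(reparameterize(mu, logvar).shape)       # torch.Size([3])
print(reparameterize(mu, logvar, n=5).shape)  # torch.Size([5, 3])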
@@ -316,10 +301,10 @@ def KLD_gaussian(
 
         if min_std <= 0 or scaler <= 0:
             raise ValueError()
-        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all)
-        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context)
-        p = torch.distributions.Normal(self.z_mu_context, std_p)
-        q = torch.distributions.Normal(self.z_mu_all, std_q)
+        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device)
+        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device)
+        p = torch.distributions.Normal(self.z_mu_context.to(device), std_p)
+        q = torch.distributions.Normal(self.z_mu_all.to(device), std_q)
         return torch.distributions.kl_divergence(p, q).sum()
 
     def posterior(
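For context, the KL term above measures how far the context-only latent distribution drifts from the latent distribution fit on all data. A small self-contained check of that computation, with toy mu/logvar values invented for illustration:

import torch

min_std, scaler = 0.01, 0.5
mu_all, logvar_all = torch.zeros(4), torch.zeros(4)
mu_ctx, logvar_ctx = 0.1 * torch.ones(4), torch.ones(4)

std_q = min_std + scaler * torch.sigmoid(logvar_all)
std_p = min_std + scaler * torch.sigmoid(logvar_ctx)
p = torch.distributions.Normal(mu_ctx, std_p)   # context distribution
q = torch.distributions.Normal(mu_all, std_q)   # all-data distribution
print(torch.distributions.kl_divergence(p, q).sum())  # scalar KL(p || q)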
@@ -343,7 +328,8 @@ def posterior(
             GPyTorchPosterior: The posterior distribution object
                 utilizing MultivariateNormal.
         """
-        mean = self.decoder(X, self.sample_z(self.z_mu_all, self.z_logvar_all))
+        X = X.to(device)
+        mean = self.decoder(X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all))
         covariance = torch.eye(X.size(0)) * covariance_multiplier
         if (observation_noise):
             covariance = covariance + observation_constant
@@ -352,20 +338,6 @@ def posterior(
         if posterior_transform is not None:
             posterior = posterior_transform(posterior)
         return posterior
-
-    def load_state_dict(
-        self,
-        state_dict: dict,
-        strict: bool = True
-    ) -> None:
-        """
-        Initialize the fully Bayesian model before loading the state dict.
-
-        Args:
-            state_dict (dict): A dictionary containing the parameters.
-            strict (bool): Case matching strictness.
-        """
-        super().load_state_dict(state_dict, strict=strict)
 
     def transform_inputs(
         self,
@@ -381,6 +353,7 @@ def transform_inputs(
         Returns:
             torch.Tensor: A tensor of transformed inputs
         """
+        X = X.to(device)
         if input_transform is not None:
             input_transform.to(X)
             return input_transform(X)
@@ -420,6 +393,11 @@ def forward(
         if y_c.size(1 - target_dim) != y_t.size(1 - target_dim):
             raise ValueError()
 
+        x_t = x_t.to(device)
+        x_c = x_c.to(device)
+        y_c = y_c.to(device)
+        y_t = y_t.to(device)
+
         self.z_mu_all, self.z_logvar_all = self.data_to_z_params(torch.cat([x_c, x_t], dim = input_dim), torch.cat([y_c, y_t], dim = target_dim))
         self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c)
         z = self.sample_z(self.z_mu_all, self.z_logvar_all)
@@ -447,12 +425,12 @@ def random_split_context_target(
                 - x_t: Target input data.
                 - y_t: Target target data.
         """
-        ind = np.arange(x.shape[0])
-        mask = np.random.choice(ind, size=n_context, replace=False)
-        x_c = torch.from_numpy(x[mask])
-        y_c = torch.from_numpy(y[mask])
-        x_t = torch.from_numpy(np.delete(x, mask, axis=0))
-        y_t = torch.from_numpy(np.delete(y, mask, axis=0))
-
+        mask = torch.randperm(x.shape[0])[:n_context]
+        x_c = torch.from_numpy(x[mask]).to(device)
+        y_c = torch.from_numpy(y[mask]).to(device)
+        splitter = torch.zeros(x.shape[0], dtype=torch.bool)
+        splitter[mask] = True
+        x_t = torch.from_numpy(x[~splitter]).to(device)
+        y_t = torch.from_numpy(y[~splitter]).to(device)
         return x_c, y_c, x_t, y_t

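The new split above keeps the whole operation in torch, using a random permutation for the context indices and a boolean mask (rather than np.delete) for the remaining target rows. A quick sketch of the same idea on a toy array, with shapes and n_context chosen arbitrarily:

import numpy as np
import torch

x = np.random.rand(10, 2).astype(np.float32)
n_context = 4

mask = torch.randperm(x.shape[0])[:n_context]        # random context indices
splitter = torch.zeros(x.shape[0], dtype=torch.bool)
splitter[mask] = True                                 # True marks context rows
x_c = torch.from_numpy(x[mask.numpy()])               # context points
x_t = torch.from_numpy(x[~splitter.numpy()])          # remaining target points
print(x_c.shape, x_t.shape)                           # torch.Size([4, 2]) torch.Size([6, 2])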