
Commit 0e28077

April Updates
1 parent 529e36c commit 0e28077

File tree

1 file changed (+53 −51)


botorch_community/models/np_regression.py

Lines changed: 53 additions & 51 deletions
@@ -22,11 +22,9 @@
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.likelihoods.likelihood import Likelihood
+from gpytorch.models.gp import GP
 from torch.nn import Module

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# Account for different acquisitions
-

 # reference: https://chrisorm.github.io/NGP.html
 class MLP(nn.Module):
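
Note: this hunk drops the module-level CUDA-by-default device global (and the stray "Account for different acquisitions" comment). The rest of the diff threads placement through a per-instance self.device instead. A minimal sketch of the idiom the commit adopts (not code from the diff itself):

import torch

# Device is inferred from the training data rather than a global default.
train_X = torch.rand(10, 2)                 # CPU here; CUDA if created there
device = train_X.device                     # stored as self.device below
layer = torch.nn.Linear(2, 4).to(device)    # submodules follow the data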
@@ -56,21 +54,21 @@ def __init__(
         prev_dim = input_dim

         for hidden_dim in hidden_dims:
-            layer = nn.Linear(prev_dim, hidden_dim).to(device)
+            layer = nn.Linear(prev_dim, hidden_dim)
             if init_func is not None:
                 init_func(layer.weight)
             layers.append(layer)
             layers.append(activation())
             prev_dim = hidden_dim

-        final_layer = nn.Linear(prev_dim, output_dim).to(device)
+        final_layer = nn.Linear(prev_dim, output_dim)
         if init_func is not None:
             init_func(final_layer.weight)
         layers.append(final_layer)
-        self.model = nn.Sequential(*layers).to(device)
+        self.model = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x.to(device))
+        return self.model(x)


 class REncoder(nn.Module):
@@ -101,7 +99,7 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for representation encoder.
@@ -112,7 +110,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Encoded representations
         """
-        return self.mlp(inputs.to(device))
+        return self.mlp(inputs)


 class ZEncoder(nn.Module):
@@ -144,14 +142,14 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )
         self.logvar_net = MLP(
             input_dim=input_dim,
             output_dim=output_dim,
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for latent encoder.
@@ -164,7 +162,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
             - Mean of the latent Gaussian distribution.
             - Log variance of the latent Gaussian distribution.
         """
-        inputs = inputs.to(device)
         return self.mean_net(inputs), self.logvar_net(inputs)


@@ -197,7 +194,7 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for decoder.
@@ -212,14 +209,17 @@ def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
             torch.Tensor: Predicted target values of shape (n x z_dim), representing #
                 of data points by z_dim.
         """
-        z = z.to(device)
-        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device)
-        x_pred = x_pred.to(device)
+        if z.dim() == 1:
+            z = z.unsqueeze(0)
+        if z.dim() == 3:
+            z = z.squeeze(0)
+        z_expanded = z.expand(x_pred.size(0), -1)
+        x_pred = x_pred
         xz = torch.cat([x_pred, z_expanded], dim=-1)
         return self.mlp(xz)


-class NeuralProcessModel(Model):
+class NeuralProcessModel(Model, GP):
     def __init__(
         self,
         train_X: torch.Tensor,
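
Note: the rewritten Decoder.forward normalizes z to 2-D before expanding it across the n prediction points, instead of unconditionally unsqueezing; the x_pred = x_pred line is a leftover no-op from the removed .to(device) call. This hunk also mixes gpytorch's GP base class into NeuralProcessModel. A toy shape check (sizes made up) illustrating the z handling:

import torch

z_dim, n = 3, 5
x_pred = torch.rand(n, 2)

for z in (torch.rand(z_dim), torch.rand(1, z_dim), torch.rand(1, 1, z_dim)):
    if z.dim() == 1:
        z = z.unsqueeze(0)       # (z_dim,)      -> (1, z_dim)
    if z.dim() == 3:
        z = z.squeeze(0)         # (1, 1, z_dim) -> (1, z_dim)
    z_expanded = z.expand(x_pred.size(0), -1)
    assert z_expanded.shape == (n, z_dim)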
@@ -262,35 +262,38 @@ def __init__(
                 forward pass.
         """
         super().__init__()
+        self.device = train_X.device
+
+        # self._validate_tensor_args(X=train_X, Y=train_Y)
         self.r_encoder = REncoder(
             x_dim + y_dim,
             r_dim,
             r_hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        ).to(self.device)
         self.z_encoder = ZEncoder(
             r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func
-        ).to(device)
+        ).to(self.device)
         self.decoder = Decoder(
             x_dim + z_dim,
             y_dim,
             decoder_hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
-        self.train_X = train_X.to(device)
-        self.train_Y = train_Y.to(device)
+        ).to(self.device)
+        self.train_X = train_X.to(self.device)
+        self.train_Y = train_Y.to(self.device)
         self.n_context = n_context
         self.z_dim = z_dim
         self.z_mu_all = None
         self.z_logvar_all = None
         self.z_mu_context = None
         self.z_logvar_context = None
         if likelihood is None:
-            self.likelihood = GaussianLikelihood().to(device)
+            self.likelihood = GaussianLikelihood().to(self.device)
         else:
-            self.likelihood = likelihood.to(device)
+            self.likelihood = likelihood.to(self.device)
         self.input_transform = input_transform

     def data_to_z_params(
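
Note: with self.device = train_X.device, the model now lives wherever its training data lives instead of unconditionally preferring CUDA. A hedged sketch of the resulting contract (toy shapes, no real constructor call):

import torch

train_X = torch.rand(20, 2)
train_Y = torch.rand(20, 1)
if torch.cuda.is_available():      # opting in is now explicit, not implicit
    train_X, train_Y = train_X.cuda(), train_Y.cuda()

# Per this diff, __init__ stores train_X.device as self.device and moves the
# encoders, decoder, and likelihood there with .to(self.device).
assert train_X.device == train_Y.device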
@@ -310,11 +313,11 @@ def data_to_z_params(
             - x_t: Target input data.
             - y_t: Target target data.
         """
-        x = x.to(device)
-        y = y.to(device)
-        xy = torch.cat([x, y], dim=-1).to(device).to(device)
+        x = x.to(self.device)
+        y = y.to(self.device)
+        xy = torch.cat([x, y], dim=-1).to(self.device).to(self.device)
         rs = self.r_encoder(xy)
-        r_agg = rs.mean(dim=r_dim).to(device)
+        r_agg = rs.mean(dim=r_dim).to(self.device)
         return self.z_encoder(r_agg)

     def sample_z(
@@ -344,11 +347,11 @@ def sample_z(
         shape = [n, self.z_dim]
         if n == 1:
             shape = shape[1:]
-        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device)
+        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(self.device)

         std = min_std + scaler * torch.sigmoid(logvar)
-        std = std.to(device)
-        mu = mu.to(device)
+        std = std.to(self.device)
+        mu = mu.to(self.device)
         return mu + std * eps

     def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor:
@@ -365,10 +368,10 @@ def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor:

         if min_std <= 0 or scaler <= 0:
             raise ValueError()
-        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device)
-        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device)
-        p = torch.distributions.Normal(self.z_mu_context.to(device), std_p)
-        q = torch.distributions.Normal(self.z_mu_all.to(device), std_q)
+        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(self.device)
+        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(self.device)
+        p = torch.distributions.Normal(self.z_mu_context.to(self.device), std_p)
+        q = torch.distributions.Normal(self.z_mu_all.to(self.device), std_q)
         return torch.distributions.kl_divergence(p, q).sum()

     def posterior(
@@ -378,7 +381,7 @@ def posterior(
         observation_noise: bool = False,
         posterior_transform: PosteriorTransform | None = None,
     ) -> GPyTorchPosterior:
-        r"""Computes the model's posterior distribution for given input tensors.
+        r"""Computes the model's posterior for given input tensors.

         Args:
             X: Input Tensor
@@ -391,20 +394,19 @@ def posterior(
                 defaults to None.

         Returns:
-            GPyTorchPosterior: The posterior distribution object
-                utilizing MultivariateNormal.
+            GPyTorchPosterior: The posterior utilizing MultivariateNormal.
         """
         X = self.transform_inputs(X)
-        X = X.to(device)
+        X = X.to(self.device)
         mean = self.decoder(
-            X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all)
+            X.to(self.device), self.sample_z(self.z_mu_all, self.z_logvar_all)
         )
         z_var = torch.exp(self.z_logvar_all)
-        covariance = torch.eye(X.size(0)).to(device) * z_var.mean()
+        covariance = torch.eye(X.size(0)).to(self.device) * z_var.mean()
         if observation_noise:
             covariance = covariance + self.likelihood.noise * torch.eye(
                 covariance.size(0)
-            ).to(device)
+            ).to(self.device)
         mvn = MultivariateNormal(mean, covariance)
         posterior = GPyTorchPosterior(mvn)
         if posterior_transform is not None:
@@ -425,7 +427,7 @@ def transform_inputs(
         Returns:
             torch.Tensor: A tensor of transformed inputs
         """
-        X = X.to(device)
+        X = X.to(self.device)
         if input_transform is not None:
             input_transform.to(X)
             return input_transform(X)
@@ -454,10 +456,10 @@ def forward(
         x_c, y_c, x_t, y_t = self.random_split_context_target(
             train_X, train_Y, self.n_context, axis=axis
         )
-        x_t = x_t.to(device)
-        x_c = x_c.to(device)
-        y_c = y_c.to(device)
-        y_t = y_t.to(device)
+        x_t = x_t.to(self.device)
+        x_c = x_c.to(self.device)
+        y_c = y_c.to(self.device)
+        y_t = y_t.to(self.device)
         self.z_mu_all, self.z_logvar_all = self.data_to_z_params(
             self.train_X, self.train_Y
         )
@@ -486,9 +488,9 @@ def random_split_context_target(
         self.n_context = n_context
         mask = torch.randperm(x.shape[axis])[:n_context]
         splitter = torch.zeros(x.shape[axis], dtype=torch.bool)
-        x_c = x[mask].to(device)
-        y_c = y[mask].to(device)
+        x_c = x[mask].to(self.device)
+        y_c = y[mask].to(self.device)
         splitter[mask] = True
-        x_t = x[~splitter].to(device)
-        y_t = y[~splitter].to(device)
+        x_t = x[~splitter].to(self.device)
+        y_t = y[~splitter].to(self.device)
         return x_c, y_c, x_t, y_t
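
Note: posterior() (above) still builds an isotropic diagonal covariance scaled by the mean latent variance. A standalone sketch of that construction with toy shapes, assuming only the imports shown:

import torch
from gpytorch.distributions import MultivariateNormal
from botorch.posteriors.gpytorch import GPyTorchPosterior

mean = torch.rand(4)                      # decoder output for 4 query points
z_var = torch.exp(torch.rand(3))          # exp of latent log-variance
covariance = torch.eye(4) * z_var.mean()  # isotropic diagonal covariance
posterior = GPyTorchPosterior(MultivariateNormal(mean, covariance))
print(posterior.mean.shape)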
