Skip to content

Commit fe75f43

Browse files
authored
Recent Fixes
1 parent 918a4b4 commit fe75f43

File tree

2 files changed

+16
-34
lines changed

2 files changed

+16
-34
lines changed

botorch_community/acquisition/latent_information_gain.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,33 +68,23 @@ def forward(self, candidate_x: Tensor) -> Tensor:
6868
candidate_x = candidate_x.to(device)
6969
N, q, D = candidate_x.shape
7070
kl = torch.zeros(N, device=device, dtype=torch.float32)
71-
def normal_dist(mu, logvar, min_std, scaler):
72-
r"""Helper function for creating the normal distributions.
73-
74-
Args:
75-
mu: Tensor representing the Gaussian distribution mean.
76-
logvar: Tensor representing the log variance of the
77-
Gaussian distribution.
78-
min_std: Float representing the minimum standardized std.
79-
scaler: Float scaling the std.
80-
81-
Returns:
82-
torch.distributions.Normal: The normal distribution.
83-
"""
84-
std = min_std + scaler * torch.sigmoid(logvar)
85-
return torch.distributions.Normal(mu, std)
71+
8672
if isinstance(self.model, NeuralProcessModel):
8773
x_c, y_c, _, _ = self.model.random_split_context_target(
8874
self.model.train_X, self.model.train_Y, self.model.n_context
8975
)
90-
z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c)
76+
self.model.z_mu_context, self.model.z_logvar_context = (
77+
self.model.data_to_z_params(x_c, y_c)
78+
)
9179

9280
for i in range(N):
9381
x_i = candidate_x[i]
9482
kl_i = 0.0
9583

9684
for _ in range(self.num_samples):
97-
sample_z = self.model.sample_z(z_mu_context, z_logvar_context)
85+
sample_z = self.model.sample_z(
86+
self.model.z_mu_context, self.model.z_logvar_context
87+
)
9888
if sample_z.dim() == 1:
9989
sample_z = sample_z.unsqueeze(0)
10090

@@ -103,15 +93,10 @@ def normal_dist(mu, logvar, min_std, scaler):
10393
combined_x = torch.cat([x_c, x_i], dim=0)
10494
combined_y = torch.cat([y_c, y_pred], dim=0)
10595

106-
z_mu_post, z_logvar_post = self.model.data_to_z_params(
107-
combined_x, combined_y
108-
)
109-
110-
p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler)
111-
q = normal_dist(
112-
z_mu_context, z_logvar_context, self.min_std, self.scaler
96+
self.model.z_mu_all, self.model.z_logvar_all = (
97+
self.model.data_to_z_params(combined_x, combined_y)
11398
)
114-
kl_sample = torch.distributions.kl_divergence(p, q).sum()
99+
kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler)
115100
kl_i += kl_sample
116101

117102
kl[i] = kl_i / self.num_samples
@@ -129,4 +114,4 @@ def normal_dist(mu, logvar, min_std, scaler):
129114
).sum()
130115

131116
kl[i] = kl_i / self.num_samples
132-
return kl
117+
return kl

botorch_community/models/np_regression.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,24 +264,23 @@ def __init__(
264264
super().__init__()
265265
self.device = train_X.device
266266

267-
# self._validate_tensor_args(X=train_X, Y=train_Y)
268267
self.r_encoder = REncoder(
269268
x_dim + y_dim,
270269
r_dim,
271270
r_hidden_dims,
272271
activation=activation,
273272
init_func=init_func,
274-
).to(self.device)
273+
)
275274
self.z_encoder = ZEncoder(
276275
r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func
277-
).to(self.device)
276+
)
278277
self.decoder = Decoder(
279278
x_dim + z_dim,
280279
y_dim,
281280
decoder_hidden_dims,
282281
activation=activation,
283282
init_func=init_func,
284-
).to(self.device)
283+
)
285284
self.train_X = train_X
286285
self.train_Y = train_Y
287286
self.n_context = n_context
@@ -290,11 +289,9 @@ def __init__(
290289
self.z_logvar_all = None
291290
self.z_mu_context = None
292291
self.z_logvar_context = None
293-
if likelihood is None:
294-
self.likelihood = GaussianLikelihood().to(self.device)
295-
else:
296-
self.likelihood = likelihood.to(self.device)
292+
self.likelihood = likelihood if likelihood is not None else GaussianLikelihood()
297293
self.input_transform = input_transform
294+
self.to(device=self.device)
298295

299296
def data_to_z_params(
300297
self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0

0 commit comments

Comments
 (0)