Skip to content

Commit 918a4b4

Browse files
authored
5/16 Updates
1 parent 2c9c958 commit 918a4b4

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

botorch_community/acquisition/latent_information_gain.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,35 @@ def __init__(
5454
self.scaler = scaler
5555

5656
def forward(self, candidate_x: Tensor) -> Tensor:
57+
"""
58+
Conduct the Latent Information Gain acquisition function for the inputs.
59+
60+
Args:
61+
candidate_x: Candidate input points, as a Tensor. Ideally in the shape
62+
(N, q, D).
63+
64+
Returns:
65+
torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
66+
"""
5767
device = candidate_x.device
5868
candidate_x = candidate_x.to(device)
5969
N, q, D = candidate_x.shape
60-
kl = torch.zeros(N, device=device)
70+
kl = torch.zeros(N, device=device, dtype=torch.float32)
71+
def normal_dist(mu, logvar, min_std, scaler):
72+
r"""Helper function for creating the normal distributions.
73+
74+
Args:
75+
mu: Tensor representing the Gaussian distribution mean.
76+
logvar: Tensor representing the log variance of the
77+
Gaussian distribution.
78+
min_std: Float representing the minimum standardized std.
79+
scaler: Float scaling the std.
80+
81+
Returns:
82+
torch.distributions.Normal: The normal distribution.
83+
"""
84+
std = min_std + scaler * torch.sigmoid(logvar)
85+
return torch.distributions.Normal(mu, std)
6186
if isinstance(self.model, NeuralProcessModel):
6287
x_c, y_c, _, _ = self.model.random_split_context_target(
6388
self.model.train_X, self.model.train_Y, self.model.n_context
@@ -82,15 +107,11 @@ def forward(self, candidate_x: Tensor) -> Tensor:
82107
combined_x, combined_y
83108
)
84109

85-
std_prior = self.min_std + self.scaler * torch.sigmoid(
86-
z_logvar_context
110+
p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler)
111+
q = normal_dist(
112+
z_mu_context, z_logvar_context, self.min_std, self.scaler
87113
)
88-
std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post)
89-
90-
p = torch.distributions.Normal(z_mu_post, std_post)
91-
q = torch.distributions.Normal(z_mu_context, std_prior)
92114
kl_sample = torch.distributions.kl_divergence(p, q).sum()
93-
94115
kl_i += kl_sample
95116

96117
kl[i] = kl_i / self.num_samples
@@ -108,4 +129,4 @@ def forward(self, candidate_x: Tensor) -> Tensor:
108129
).sum()
109130

110131
kl[i] = kl_i / self.num_samples
111-
return kl
132+
return kl

botorch_community/models/np_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def __init__(
282282
activation=activation,
283283
init_func=init_func,
284284
).to(self.device)
285-
self.train_X = train_X.to(self.device)
286-
self.train_Y = train_Y.to(self.device)
285+
self.train_X = train_X
286+
self.train_Y = train_Y
287287
self.n_context = n_context
288288
self.z_dim = z_dim
289289
self.z_mu_all = None

0 commit comments

Comments
 (0)