
Commit 0ceb9ca

April Updates
1 parent 0e28077 commit 0ceb9ca

File tree

1 file changed: 48 additions, 50 deletions

botorch_community/acquisition/latent_information_gain.py

Lines changed: 48 additions & 50 deletions
@@ -27,8 +27,6 @@
 from torch import Tensor
 # reference: https://arxiv.org/abs/2106.02770
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 
 class LatentInformationGain(AcquisitionFunction):
     def __init__(
@@ -56,58 +54,58 @@ def __init__(
         self.scaler = scaler
 
     def forward(self, candidate_x: Tensor) -> Tensor:
-        """
-        Conduct the Latent Information Gain acquisition function using the model's
-        posterior.
-
-        Args:
-            candidate_x: Candidate input points, as a Tensor. Ideally in the shape
-            (N, q, D).
-
-        Returns:
-            torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
-        """
+        device = candidate_x.device
         candidate_x = candidate_x.to(device)
-        if candidate_x.dim() == 2:
-            candidate_x = candidate_x.unsqueeze(0)  # Ensure (N, q, D) format
         N, q, D = candidate_x.shape
-
-        kl = torch.zeros(N, q, device=device)
-
+        kl = torch.zeros(N, device=device)
         if isinstance(self.model, NeuralProcessModel):
-            x_c, y_c, x_t, y_t = self.model.random_split_context_target(
-                self.model.train_X,
-                self.model.train_Y,
-                self.model.n_context
+            x_c, y_c, _, _ = self.model.random_split_context_target(
+                self.model.train_X, self.model.train_Y, self.model.n_context
             )
             z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c)
-            for _ in range(self.num_samples):
-                # Taking Samples/Predictions
-                samples = self.model.sample_z(z_mu_context, z_logvar_context)
-                y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
-                # Combining the data
-                combined_x = torch.cat([x_c, candidate_x.view(-1, D)], dim=0).to(device)
-                combined_y = torch.cat([y_c, y_pred], dim=0).to(device)
-                # Computing posterior variables
-                z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
-                    combined_x, combined_y
-                )
-                std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
-                std_posterior = self.min_std + self.scaler * torch.sigmoid(
-                    z_logvar_posterior
-                )
-                p = torch.distributions.Normal(z_mu_posterior, std_posterior)
-                q = torch.distributions.Normal(z_mu_context, std_prior)
-                kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1)
-                kl += kl_divergence
-        else:
-            for _ in range(self.num_samples):
-                posterior_prior = self.model.posterior(self.model.train_X)
-                posterior_candidate = self.model.posterior(candidate_x.view(-1, D))
 
-                kl_divergence = torch.distributions.kl_divergence(
-                    posterior_candidate.mvn, posterior_prior.mvn
-                ).sum(dim=-1)
-                kl += kl_divergence
+            for i in range(N):
+                x_i = candidate_x[i]
+                kl_i = 0.0
 
-        return kl / self.num_samples
+                for _ in range(self.num_samples):
+                    sample_z = self.model.sample_z(z_mu_context, z_logvar_context)
+                    if sample_z.dim() == 1:
+                        sample_z = sample_z.unsqueeze(0)
+
+                    y_pred = self.model.decoder(x_i, sample_z)
+
+                    combined_x = torch.cat([x_c, x_i], dim=0)
+                    combined_y = torch.cat([y_c, y_pred], dim=0)
+
+                    z_mu_post, z_logvar_post = self.model.data_to_z_params(
+                        combined_x, combined_y
+                    )
+
+                    std_prior = self.min_std + self.scaler * torch.sigmoid(
+                        z_logvar_context
+                    )
+                    std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post)
+
+                    p = torch.distributions.Normal(z_mu_post, std_post)
+                    q = torch.distributions.Normal(z_mu_context, std_prior)
+                    kl_sample = torch.distributions.kl_divergence(p, q).sum()
+
+                    kl_i += kl_sample
+
+                kl[i] = kl_i / self.num_samples
+
+        else:
+            for i in range(N):
+                x_i = candidate_x[i]
+                kl_i = 0.0
+                for _ in range(self.num_samples):
+                    posterior_prior = self.model.posterior(self.model.train_X)
+                    posterior_candidate = self.model.posterior(x_i)
+
+                    kl_i += torch.distributions.kl_divergence(
+                        posterior_candidate.mvn, posterior_prior.mvn
+                    ).sum()
+
+                kl[i] = kl_i / self.num_samples
+        return kl
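
In the NeuralProcessModel branch, the commit scores each of the N candidate batches separately: for every Monte Carlo draw of the latent variable it decodes predictions for the candidate, refits the latent distribution on context plus candidate, and accumulates the KL divergence between that distribution and the context-only prior, averaging over num_samples. A minimal, self-contained sketch of what a single draw contributes; the latent statistics and the min_std/scaler values below are placeholders for illustration, not values from this commit:

import torch

# Placeholder latent statistics for an assumed 4-dimensional latent space.
z_mu_context = torch.zeros(4)
z_logvar_context = torch.zeros(4)
z_mu_post = torch.full((4,), 0.1)
z_logvar_post = torch.full((4,), -0.2)

min_std, scaler = 0.01, 0.5  # stand-ins for self.min_std / self.scaler

# Same transform as the diff: squash log-variances into bounded standard deviations.
std_prior = min_std + scaler * torch.sigmoid(z_logvar_context)
std_post = min_std + scaler * torch.sigmoid(z_logvar_post)

p = torch.distributions.Normal(z_mu_post, std_post)
q = torch.distributions.Normal(z_mu_context, std_prior)

# KL(p || q) summed over latent dimensions: one draw's contribution to kl_i.
kl_sample = torch.distributions.kl_divergence(p, q).sum()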

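For the generic-model branch, a hypothetical call pattern sketched against a standard SingleTaskGP; the constructor keywords are assumptions read off the attributes visible in the diff (self.num_samples, self.min_std, self.scaler), not a documented signature. Since torch.distributions.kl_divergence between two MultivariateNormal posteriors requires matching event shapes, q is set equal to the number of training points here:

import torch
from botorch.models import SingleTaskGP
from botorch_community.acquisition.latent_information_gain import LatentInformationGain

train_X = torch.rand(8, 3, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# Hypothetical constructor call; num_samples controls the Monte Carlo averaging.
lig = LatentInformationGain(model=model, num_samples=4)

candidate_x = torch.rand(5, 8, 3, dtype=torch.double)  # (N, q, D), q == n_train
scores = lig(candidate_x)  # after this commit: shape (N,), one averaged KL per batch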