
Commit 4aeeeeb

Update Acquisition Dimensions
1 parent 9684e1d · commit 4aeeeeb

1 file changed: +11 −18 lines

botorch_community/acquisition/latent_information_gain.py

Lines changed: 11 additions & 18 deletions
@@ -62,43 +62,36 @@ def __init__(
         self.context_x = context_x.to(device)
         self.context_y = context_y.to(device)
 
-    def forward(self, candidate_x):
+    def forward(self, candidate_x: Tensor) -> Tensor:
         """
         Conduct the Latent Information Gain acquisition function for the inputs.
 
         Args:
-            candidate_x: Candidate input points, as a Tensor.
+            candidate_x: Candidate input points, as a Tensor of shape (N, q, D); a 2-D input is treated as N = 1.
 
         Returns:
-            torch.Tensor: The LIG score of computed KLDs.
+            torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
         """
-
         candidate_x = candidate_x.to(device)
-
+        if candidate_x.dim() == 2:
+            candidate_x = candidate_x.unsqueeze(0)
+        N, q, D = candidate_x.shape
         # Encoding and Scaling the context data
         z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
-        kl = 0.0
+        kl = torch.zeros(N, q, device=device)
         for _ in range(self.num_samples):
-            # Taking reparameterized samples
+            # Taking samples/predictions
             samples = self.model.sample_z(z_mu_context, z_logvar_context)
-
-            # Using the Decoder to take predicted values
-            y_pred = self.model.decoder(candidate_x, samples)
-
-            # Combining context and candidate data
-            combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
+            y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
+            # Combining the context and candidate data
+            combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device)
             combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
-
             # Computing posterior variables
             z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
             std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
             std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior)
-
             p = torch.distributions.Normal(z_mu_posterior, std_posterior)
             q = torch.distributions.Normal(z_mu_context, std_prior)
-
             kl_divergence = torch.distributions.kl_divergence(p, q).sum()
             kl += kl_divergence
-
-        # Average KLD
         return kl / self.num_samples
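
For orientation, the shape handling introduced by this commit can be exercised on its own. A minimal sketch follows; the tensor sizes are arbitrary illustrations, not values from the commit:

import torch

# Hypothetical batched candidates: N = 5 batches of q = 2 points in D = 3 dims.
candidate_x = torch.rand(5, 2, 3)

# The commit's 2-D fallback: a (q, D) input becomes a single batch (1, q, D).
if candidate_x.dim() == 2:
    candidate_x = candidate_x.unsqueeze(0)
N, q, D = candidate_x.shape

# view(-1, D) flattens the batch so a decoder that consumes (num_points, D)
# inputs can score all N * q candidates in one call.
flat = candidate_x.view(-1, D)
print(flat.shape)  # torch.Size([10, 3])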

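For reference, the KL term accumulated inside the loop is the analytic Gaussian KL divergence between the latent posterior (context plus candidates) and the latent prior (context only). A self-contained sketch of that step, with hypothetical latent statistics and hyperparameters standing in for the module's own:

import torch

# Hypothetical hyperparameters; the real values live on the acquisition module.
min_std, scaler = 0.01, 0.5
z_dim = 8

# Hypothetical latent statistics from the context set and the augmented set.
z_mu_context = torch.zeros(z_dim)
z_logvar_context = torch.zeros(z_dim)
z_mu_posterior = 0.1 * torch.randn(z_dim)
z_logvar_posterior = torch.zeros(z_dim)

# The code maps log-variances through a sigmoid to bounded standard deviations.
std_prior = min_std + scaler * torch.sigmoid(z_logvar_context)
std_posterior = min_std + scaler * torch.sigmoid(z_logvar_posterior)

p = torch.distributions.Normal(z_mu_posterior, std_posterior)
q = torch.distributions.Normal(z_mu_context, std_prior)

# Summed elementwise KL(p || q); forward averages this quantity over num_samples draws.
kl = torch.distributions.kl_divergence(p, q).sum()
print(kl)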