Commit 204ba31

Updated LIG Parameters/Generalizability
1 parent be34f60 commit 204ba31

1 file changed

botorch_community/acquisition/latent_information_gain.py

Lines changed: 56 additions & 43 deletions
@@ -18,80 +18,93 @@
 """
 
 from __future__ import annotations
-
-import warnings
-from typing import Optional
-
+from typing import Type, Any
 import torch
 from botorch.acquisition import AcquisitionFunction
 from botorch_community.models.np_regression import NeuralProcessModel
 from torch import Tensor
-
-import torch
-#reference: https://arxiv.org/abs/2106.02770
+# reference: https://arxiv.org/abs/2106.02770
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
 class LatentInformationGain(AcquisitionFunction):
     def __init__(
-        self,
-        context_x: torch.Tensor,
-        context_y: torch.Tensor,
-        model: NeuralProcessModel,
+        self,
+        model: Type[Any] = NeuralProcessModel,
         num_samples: int = 10,
         min_std: float = 0.01,
-        scaler: float = 0.5
+        scaler: float = 0.5,
     ) -> None:
         """
-        Latent Information Gain (LIG) Acquisition Function, designed for the
-        NeuralProcessModel. This is a subclass of AcquisitionFunction.
+        Latent Information Gain (LIG) Acquisition Function.
+        Uses the model's built-in posterior function to generalize KL computation.
 
         Args:
-            model: Trained NeuralProcessModel.
-            context_x: Context input points, as a Tensor.
-            context_y: Context target points, as a Tensor.
+            model: The model class to be used, defaults to NeuralProcessModel.
             num_samples (int): Number of samples for calculation, defaults to 10.
-            min_std: Float representing the minimum possible standardized std, defaults to 0.1.
-            scaler: Float scaling the std, defaults to 0.9.
+            min_std: Float representing the minimum possible standardized std,
+                defaults to 0.01.
+            scaler: Float scaling the std, defaults to 0.5.
         """
-        super().__init__(model=model)
-        self.model = model.to(device)
+        super().__init__()
+        self.model = model
         self.num_samples = num_samples
         self.min_std = min_std
         self.scaler = scaler
-        self.context_x = context_x.to(device)
-        self.context_y = context_y.to(device)
 
     def forward(self, candidate_x: Tensor) -> Tensor:
         """
-        Conduct the Latent Information Gain acquisition function for the inputs.
+        Conduct the Latent Information Gain acquisition function using the model's
+        posterior.
 
         Args:
-            candidate_x: Candidate input points, as a Tensor. Ideally in the shape (N, q, D), and assumes N = 1 if the given dimensions are 2D.
+            candidate_x: Candidate input points, as a Tensor. Ideally in the shape
+                (N, q, D).
 
         Returns:
             torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
         """
         candidate_x = candidate_x.to(device)
         if candidate_x.dim() == 2:
-            candidate_x = candidate_x.unsqueeze(0)
+            candidate_x = candidate_x.unsqueeze(0)  # Ensure (N, q, D) format
         N, q, D = candidate_x.shape
-        # Encoding and Scaling the context data
-        z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
+
         kl = torch.zeros(N, q, device=device)
-        for _ in range(self.num_samples):
-            # Taking Samples/Predictions
-            samples = self.model.sample_z(z_mu_context, z_logvar_context)
-            y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
-            # Combining the data
-            combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device)
-            combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
-            # Computing posterior variables
-            z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
-            std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
-            std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior)
-            p = torch.distributions.Normal(z_mu_posterior, std_posterior)
-            q = torch.distributions.Normal(z_mu_context, std_prior)
-            kl_divergence = torch.distributions.kl_divergence(p, q).sum()
-            kl += kl_divergence
+
+        if self.model is NeuralProcessModel:
+            z_mu_context, z_logvar_context = self.model.data_to_z_params(
+                self.context_x, self.context_y
+            )
+            for _ in range(self.num_samples):
+                # Taking Samples/Predictions
+                samples = self.model.sample_z(z_mu_context, z_logvar_context)
+                y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
+                # Combining the data
+                combined_x = torch.cat(
+                    [self.context_x, candidate_x.view(-1, D)], dim=0
+                ).to(device)
+                combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
+                # Computing posterior variables
+                z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
+                    combined_x, combined_y
+                )
+                std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
+                std_posterior = self.min_std + self.scaler * torch.sigmoid(
+                    z_logvar_posterior
+                )
+                p = torch.distributions.Normal(z_mu_posterior, std_posterior)
+                q = torch.distributions.Normal(z_mu_context, std_prior)
+                kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1)
+                kl += kl_divergence
+        else:
+            for _ in range(self.num_samples):
+                posterior_prior = self.model.posterior(self.model.train_X)
+                posterior_candidate = self.model.posterior(candidate_x.view(-1, D))
+
+                kl_divergence = torch.distributions.kl_divergence(
+                    posterior_candidate.mvn, posterior_prior.mvn
+                ).sum(dim=-1)
+                kl += kl_divergence
+
         return kl / self.num_samples
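
For reference, the KL term accumulated in the NeuralProcessModel branch above is the closed-form KL divergence between diagonal Gaussians built from the encoded latent parameters. The following standalone sketch (not part of the commit) reproduces that calculation, using made-up stand-ins for the data_to_z_params outputs and the updated defaults min_std=0.01 and scaler=0.5:

import torch

# Stand-in latent parameters; in the class these come from
# self.model.data_to_z_params(...) on the context and combined data.
z_mu_context = torch.zeros(8)
z_logvar_context = torch.zeros(8)
z_mu_posterior = 0.1 * torch.randn(8)
z_logvar_posterior = 0.2 * torch.randn(8)

min_std, scaler = 0.01, 0.5  # defaults from the updated __init__
std_prior = min_std + scaler * torch.sigmoid(z_logvar_context)
std_posterior = min_std + scaler * torch.sigmoid(z_logvar_posterior)

p = torch.distributions.Normal(z_mu_posterior, std_posterior)
q = torch.distributions.Normal(z_mu_context, std_prior)

# Elementwise KL per latent dimension, summed over the last dimension,
# mirroring kl_divergence(p, q).sum(dim=-1) in the diff.
kl = torch.distributions.kl_divergence(p, q).sum(dim=-1)
print(kl)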

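The generic else branch instead compares Gaussian posteriors returned by the model's posterior method. A rough illustration of that comparison with a BoTorch SingleTaskGP is sketched below; the toy data, the explicit reference points used in place of model.train_X, and the equal point counts (required for torch's multivariate-normal KL) are assumptions made for the example, not part of the commit:

import torch
from botorch.models import SingleTaskGP

# Toy training data for an illustrative GP surrogate (assumed, not from the commit).
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
gp = SingleTaskGP(train_X, train_Y)

# Reference points standing in for the "prior" posterior, and candidate points;
# both contain 8 points so the two multivariate normals are KL-comparable.
X_ref = torch.rand(8, 3, dtype=torch.double)
X_cand = torch.rand(8, 3, dtype=torch.double)

posterior_prior = gp.posterior(X_ref)
posterior_candidate = gp.posterior(X_cand)

# KL divergence between the two GP posteriors, as in the diff's generic branch.
kl = torch.distributions.kl_divergence(posterior_candidate.mvn, posterior_prior.mvn)
print(kl)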