
Commit 4f35e0f

1/25 Updates
1 parent 657151f commit 4f35e0f

1 file changed: 21 additions & 12 deletions

botorch_community/acquisition/latent_information_gain.py

@@ -23,51 +23,60 @@
 from typing import Optional

 import torch
-from botorch import settings
+from botorch.acquisition import AcquisitionFunction
 from botorch_community.models.np_regression import NeuralProcessModel
 from torch import Tensor

 import torch
 #reference: https://arxiv.org/abs/2106.02770

-class LatentInformationGain:
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class LatentInformationGain(AcquisitionFunction):
     def __init__(
         self,
+        context_x: torch.Tensor,
+        context_y: torch.Tensor,
         model: NeuralProcessModel,
         num_samples: int = 10,
-        min_std: float = 0.1,
-        scaler: float = 0.9
+        min_std: float = 0.01,
+        scaler: float = 0.5
     ) -> None:
         """
         Latent Information Gain (LIG) Acquisition Function, designed for the
-        NeuralProcessModel.
+        NeuralProcessModel. This is a subclass of AcquisitionFunction.

         Args:
             model: Trained NeuralProcessModel.
+            context_x: Context input points, as a Tensor.
+            context_y: Context target points, as a Tensor.
             num_samples (int): Number of samples for calculation, defaults to 10.
             min_std: Float representing the minimum possible standardized std, defaults to 0.1.
             scaler: Float scaling the std, defaults to 0.9.
         """
-        self.model = model
+        super().__init__(model=model)
+        self.model = model.to(device)
         self.num_samples = num_samples
         self.min_std = min_std
         self.scaler = scaler
+        self.context_x = context_x.to(device)
+        self.context_y = context_y.to(device)

-    def acquisition(self, candidate_x, context_x, context_y):
+    def forward(self, candidate_x):
         """
         Conduct the Latent Information Gain acquisition function for the inputs.

         Args:
             candidate_x: Candidate input points, as a Tensor.
-            context_x: Context input points, as a Tensor.
-            context_y: Context target points, as a Tensor.

         Returns:
             torch.Tensor: The LIG score of computed KLDs.
         """

+        candidate_x = candidate_x.to(device)
+
         # Encoding and Scaling the context data
-        z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y)
+        z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
         kl = 0.0
         for _ in range(self.num_samples):
             # Taking reparameterized samples
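
Between the two hunks the diff elides the sampling step (old lines 74-76). Purely as an illustration, and assuming the usual Neural Process parameterization suggested by the class's min_std and scaler attributes (this is a reconstruction, not code from this commit), the elided body plausibly reparameterizes samples from the context latent distribution:

    # Hypothetical sketch of the elided sampling step (an assumption,
    # not part of this commit): shift/scale the latent log-variance
    # into a std, then draw a reparameterized sample.
    std = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
    eps = torch.randn_like(std)
    samples = z_mu_context + std * eps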
@@ -77,8 +86,8 @@ def acquisition(self, candidate_x, context_x, context_y):
             y_pred = self.model.decoder(candidate_x, samples)

             # Combining context and candidate data
-            combined_x = torch.cat([context_x, candidate_x], dim=0)
-            combined_y = torch.cat([context_y, y_pred], dim=0)
+            combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
+            combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)

             # Computing posterior variables
             z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
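
The second hunk cuts off just before the KL accumulation. For orientation only: the KL divergence between the posterior and context latent distributions, both diagonal Gaussians parameterized by a mean and a log-variance, has a closed form. The helper below is hypothetical and is not part of this commit:

    import torch

    def diagonal_gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
        # Closed-form KL(q || p) for diagonal Gaussians, summed over latent dims:
        # 0.5 * sum(logvar_p - logvar_q + (var_q + (mu_q - mu_p)^2) / var_p - 1)
        var_q = torch.exp(logvar_q)
        var_p = torch.exp(logvar_p)
        return 0.5 * torch.sum(
            logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
        )

Inside the sampling loop this would accumulate as, e.g., kl += diagonal_gaussian_kl(z_mu_posterior, z_logvar_posterior, z_mu_context, z_logvar_context), with forward() returning an average such as kl / self.num_samples.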

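Finally, a minimal usage sketch of the new interface, assuming a trained NeuralProcessModel (the tensor shapes and the elided constructor arguments below are illustrative, not taken from this commit):

    import torch

    from botorch_community.acquisition.latent_information_gain import LatentInformationGain
    from botorch_community.models.np_regression import NeuralProcessModel

    # Hypothetical data: 20 observed 2-d inputs with scalar targets.
    context_x = torch.rand(20, 2)
    context_y = torch.rand(20, 1)

    model = NeuralProcessModel(...)  # constructor arguments omitted; see np_regression.py

    lig = LatentInformationGain(
        context_x=context_x,
        context_y=context_y,
        model=model,
        num_samples=10,
    )

    candidate_x = torch.rand(5, 2)   # five candidate points to score
    score = lig(candidate_x)         # nn.Module __call__ dispatches to forward()

Since AcquisitionFunction subclasses torch.nn.Module, calling the instance invokes the new forward method directly.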