Skip to content

Commit bf95d41

Browse files
authored
LIG Updated Parameters
1 parent e13f38c commit bf95d41

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

botorch_community/acquisition/latent_information_gain.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
"""
1919

2020
from __future__ import annotations
21-
from typing import Type, Any
21+
22+
from typing import Any, Type
23+
2224
import torch
2325
from botorch.acquisition import AcquisitionFunction
2426
from botorch_community.models.np_regression import NeuralProcessModel
@@ -42,7 +44,7 @@ def __init__(
4244
4345
Args:
4446
model: The model class to be used, defaults to NeuralProcessModel.
45-
num_samples (int): Number of samples for calculation, defaults to 10.
47+
num_samples: Int showing the # of samples for calculation, defaults to 10.
4648
min_std: Float representing the minimum possible standardized std,
4749
defaults to 0.01.
4850
scaler: Float scaling the std, defaults to 0.5.
@@ -74,26 +76,18 @@ def forward(self, candidate_x: Tensor) -> Tensor:
7476

7577
if isinstance(self.model, NeuralProcessModel):
7678
x_c, y_c, x_t, y_t = self.model.random_split_context_target(
77-
self.model.train_X[:, 0], self.model.train_Y
79+
self.model.train_X,
80+
self.model.train_Y,
81+
self.model.n_context
7882
)
79-
print(x_c.shape)
80-
print(y_c.shape)
81-
print(self.model.train_X)
82-
print(self.model.train_X[:, 0])
83-
print(self.model.train_Y)
84-
print(self.model.train_Y[:, 0])
85-
z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c, xy_dim = -1)
86-
print(z_mu_context)
87-
print(z_logvar_context)
83+
z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c)
8884
for _ in range(self.num_samples):
8985
# Taking Samples/Predictions
9086
samples = self.model.sample_z(z_mu_context, z_logvar_context)
9187
y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
9288
# Combining the data
93-
combined_x = torch.cat(
94-
[x_c, candidate_x.view(-1, D)], dim=0
95-
).to(device)
96-
combined_y = torch.cat([self.y_c, y_pred], dim=0).to(device)
89+
combined_x = torch.cat([x_c, candidate_x.view(-1, D)], dim=0).to(device)
90+
combined_y = torch.cat([y_c, y_pred], dim=0).to(device)
9791
# Computing posterior variables
9892
z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
9993
combined_x, combined_y

0 commit comments

Comments
 (0)