Skip to content

Commit d78d262

Browse files
authored
LIG WIP
1 parent 7232725 commit d78d262

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

botorch_community/acquisition/latent_information_gain.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
defaults to 0.01.
4848
scaler: Float scaling the std, defaults to 0.5.
4949
"""
50-
super().__init__()
50+
super().__init__(model)
5151
self.model = model
5252
self.num_samples = num_samples
5353
self.min_std = min_std
@@ -72,19 +72,28 @@ def forward(self, candidate_x: Tensor) -> Tensor:
7272

7373
kl = torch.zeros(N, q, device=device)
7474

75-
if self.model is NeuralProcessModel:
76-
z_mu_context, z_logvar_context = self.model.data_to_z_params(
77-
self.context_x, self.context_y
75+
if isinstance(self.model, NeuralProcessModel):
76+
x_c, y_c, x_t, y_t = self.model.random_split_context_target(
77+
self.model.train_X[:, 0], self.model.train_Y
7878
)
79+
print(x_c.shape)
80+
print(y_c.shape)
81+
print(self.model.train_X)
82+
print(self.model.train_X[:, 0])
83+
print(self.model.train_Y)
84+
print(self.model.train_Y[:, 0])
85+
z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c, xy_dim = -1)
86+
print(z_mu_context)
87+
print(z_logvar_context)
7988
for _ in range(self.num_samples):
8089
# Taking Samples/Predictions
8190
samples = self.model.sample_z(z_mu_context, z_logvar_context)
8291
y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
8392
# Combining the data
8493
combined_x = torch.cat(
85-
[self.context_x, candidate_x.view(-1, D)], dim=0
94+
[x_c, candidate_x.view(-1, D)], dim=0
8695
).to(device)
87-
combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
96+
combined_y = torch.cat([self.y_c, y_pred], dim=0).to(device)
8897
# Computing posterior variables
8998
z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
9099
combined_x, combined_y

0 commit comments

Comments
 (0)