
Commit ef43cc1

Cleaned up code
1 parent 43d8c32 commit ef43cc1

2 files changed: +618 -610 lines changed

Lines changed: 128 additions & 117 deletions
@@ -1,117 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Latent Information Gain Acquisition Function for Neural Process Models.

References:

.. [Wu2023arxiv]
    Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023).
    Deep Bayesian Active Learning for Accelerating Stochastic Simulation.
    arXiv preprint arXiv:2106.02770. Retrieved from https://arxiv.org/abs/2106.02770

Contributor: eibarolle
"""

from __future__ import annotations

from typing import Any, Type

import torch
from botorch.acquisition import AcquisitionFunction
from botorch_community.models.np_regression import NeuralProcessModel
from torch import Tensor
# reference: https://arxiv.org/abs/2106.02770


class LatentInformationGain(AcquisitionFunction):
    def __init__(
        self,
        model: Type[Any],
        num_samples: int = 10,
        min_std: float = 0.01,
        scaler: float = 0.5,
    ) -> None:
        """
        Latent Information Gain (LIG) Acquisition Function.
        Uses the model's built-in posterior function to generalize KL computation.

        Args:
            model: The model class to be used, defaults to NeuralProcessModel.
            num_samples: Int showing the # of samples for calculation, defaults to 10.
            min_std: Float representing the minimum possible standardized std,
                defaults to 0.01.
            scaler: Float scaling the std, defaults to 0.5.
        """
        super().__init__(model)
        self.model = model
        self.num_samples = num_samples
        self.min_std = min_std
        self.scaler = scaler

    def forward(self, candidate_x: Tensor) -> Tensor:
        """
        Conduct the Latent Information Gain acquisition function for the inputs.

        Args:
            candidate_x: Candidate input points, as a Tensor. Ideally in the shape
                (N, q, D).

        Returns:
            torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
        """
        device = candidate_x.device
        candidate_x = candidate_x.to(device)
        N, q, D = candidate_x.shape
        kl = torch.zeros(N, device=device, dtype=torch.float32)

        if isinstance(self.model, NeuralProcessModel):
            x_c, y_c, _, _ = self.model.random_split_context_target(
                self.model.train_X, self.model.train_Y, self.model.n_context
            )
            self.model.z_mu_context, self.model.z_logvar_context = (
                self.model.data_to_z_params(x_c, y_c)
            )

            for i in range(N):
                x_i = candidate_x[i]
                kl_i = 0.0

                for _ in range(self.num_samples):
                    sample_z = self.model.sample_z(
                        self.model.z_mu_context, self.model.z_logvar_context
                    )
                    if sample_z.dim() == 1:
                        sample_z = sample_z.unsqueeze(0)

                    y_pred = self.model.decoder(x_i, sample_z)

                    combined_x = torch.cat([x_c, x_i], dim=0)
                    combined_y = torch.cat([y_c, y_pred], dim=0)

                    self.model.z_mu_all, self.model.z_logvar_all = (
                        self.model.data_to_z_params(combined_x, combined_y)
                    )
                    kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler)
                    kl_i += kl_sample

                kl[i] = kl_i / self.num_samples

        else:
            for i in range(N):
                x_i = candidate_x[i]
                kl_i = 0.0
                for _ in range(self.num_samples):
-                    posterior_prior = self.model.posterior(self.model.train_X)
-                    posterior_candidate = self.model.posterior(x_i)
-
-                    kl_i += torch.distributions.kl_divergence(
-                        posterior_candidate.mvn, posterior_prior.mvn
-                    ).sum()
-
-                kl[i] = kl_i / self.num_samples
-        return kl
+                    posterior_prior = self.model.posterior(self.model.train_inputs[0])
+                    posterior_candidate = self.model.posterior(x_i)
+
+                    mean_prior = posterior_prior.mean.mean(dim=0)
+                    cov_prior = posterior_prior.variance.mean(dim=0)
+                    mvn_prior = torch.distributions.MultivariateNormal(
+                        mean_prior, torch.diag(cov_prior)
+                    )
+
+                    mean_candidate = posterior_candidate.mean.mean(dim=0)
+                    cov_candidate = posterior_candidate.variance.mean(dim=0)
+                    mvn_candidate = torch.distributions.MultivariateNormal(
+                        mean_candidate, torch.diag(cov_candidate)
+                    )
+
+                    kl_i += torch.distributions.kl_divergence(mvn_candidate, mvn_prior)
+
+                kl[i] = kl_i / self.num_samples
+
+        return kl
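
For reference, a minimal usage sketch of the class above, exercising the generic (non-NeuralProcessModel) branch with a standard BoTorch SingleTaskGP. This is not part of the commit: the toy data, GP model, and tensor shapes are illustrative assumptions, and the LatentInformationGain class defined in this file is assumed to be in scope.

import torch
from botorch.models import SingleTaskGP

# Toy training data (illustrative only); fitting the GP is omitted for brevity.
train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
gp = SingleTaskGP(train_X, train_Y)

# LatentInformationGain as defined above, with its default hyperparameters.
lig = LatentInformationGain(model=gp, num_samples=10)

# forward() expects candidates in the (N, q, D) layout.
candidate_x = torch.rand(5, 3, 2, dtype=torch.float64)
scores = lig(candidate_x)  # one averaged KL score per candidate batch, shape (5,)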

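A small self-contained check (also not from the commit) of the quantity the updated else branch computes: the KL divergence between the diagonal Gaussians built from posterior moments, compared against its closed form. The moments below are made-up example values.

import torch

# Hypothetical posterior moments for a candidate (q) and the training data (p).
mu_q, var_q = torch.tensor([0.2]), torch.tensor([0.5])
mu_p, var_p = torch.tensor([0.0]), torch.tensor([1.0])

q = torch.distributions.MultivariateNormal(mu_q, torch.diag(var_q))
p = torch.distributions.MultivariateNormal(mu_p, torch.diag(var_p))

# Closed form for diagonal Gaussians:
# KL(q || p) = 0.5 * sum(log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1)
kl_closed = 0.5 * (
    torch.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1
).sum()

print(torch.distributions.kl_divergence(q, p), kl_closed)  # the two values agree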