Commit 9304881

Latent Information Gain
1 parent 7ca0b2e commit 9304881

File tree

1 file changed: 95 additions, 0 deletions

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Latent Information Gain Acquisition Function for Neural Process Models.

References:

.. [Wu2023arxiv]
    Wu, D., Niu, R., Chinazzi, M., Vespignani, A., Ma, Y.-A., & Yu, R. (2023).
    Deep Bayesian Active Learning for Accelerating Stochastic Simulation.
    arXiv preprint arXiv:2106.02770. https://arxiv.org/abs/2106.02770

Contributor: eibarolle
"""

from __future__ import annotations

import torch
from botorch_community.models.np_regression import NeuralProcessModel
from torch import Tensor
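
# Overview of the computation below: the acquisition score is a Monte Carlo
# estimate of the expected KL divergence between the latent posterior after
# adding the candidate points and the latent prior from the context alone
# (see [Wu2023arxiv]),
#
#   LIG(x) ~= (1/S) * sum_s KL( N(mu_post, std_post) || N(mu_ctx, std_ctx) ),
#
# where torch.distributions.kl_divergence evaluates the closed form for
# diagonal Gaussians elementwise:
#
#   KL(p || q) = log(std_q / std_p)
#                + (std_p**2 + (mu_p - mu_q)**2) / (2 * std_q**2) - 1/2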

class LatentInformationGain:
    def __init__(
        self,
        model: NeuralProcessModel,
        num_samples: int = 10,
        min_std: float = 0.1,
        scaler: float = 0.9,
    ) -> None:
        r"""
        Latent Information Gain (LIG) acquisition function, designed for the
        NeuralProcessModel.

        Args:
            model: Trained NeuralProcessModel.
            num_samples: Number of latent samples used in the Monte Carlo
                estimate; defaults to 10.
            min_std: Minimum standard deviation assigned to the latent
                distributions; defaults to 0.1.
            scaler: Scaling factor applied to the sigmoid-transformed
                log-variance; defaults to 0.9.
        """
        self.model = model
        self.num_samples = num_samples
        self.min_std = min_std
        self.scaler = scaler

    def acquisition(
        self, candidate_x: Tensor, context_x: Tensor, context_y: Tensor
    ) -> Tensor:
        r"""
        Evaluate the Latent Information Gain acquisition function for the inputs.

        Args:
            candidate_x: Candidate input points, as a Tensor.
            context_x: Context input points, as a Tensor.
            context_y: Context target points, as a Tensor.

        Returns:
            torch.Tensor: The LIG score, i.e. the KL divergence between the
            latent posterior and prior, averaged over the latent samples.
        """
        # Encode the context data into the latent prior parameters.
        z_mu_context, z_logvar_context = self.model.data_to_z_params(
            context_x, context_y
        )

        kl = 0.0
        for _ in range(self.num_samples):
            # Draw a reparameterized sample of the latent variable.
            samples = self.model.sample_z(z_mu_context, z_logvar_context)

            # Decode the sample into predictions at the candidate points.
            y_pred = self.model.decoder(candidate_x, samples)

            # Augment the context set with the candidates and their
            # predicted targets.
            combined_x = torch.cat([context_x, candidate_x], dim=0)
            combined_y = torch.cat([context_y, y_pred], dim=0)

            # Re-encode to obtain the latent posterior parameters.
            z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
                combined_x, combined_y
            )

            # Map log-variances to standard deviations bounded in
            # (min_std, min_std + scaler) via a scaled sigmoid.
            std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
            std_posterior = self.min_std + self.scaler * torch.sigmoid(
                z_logvar_posterior
            )

            # KL(posterior || prior) for this latent sample.
            p = torch.distributions.Normal(z_mu_posterior, std_posterior)
            q = torch.distributions.Normal(z_mu_context, std_prior)
            kl += torch.distributions.kl_divergence(p, q).sum()

        # Average the KL divergence over the latent samples.
        return kl / self.num_samples
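
A minimal usage sketch, assuming a NeuralProcessModel trained elsewhere (its
constructor arguments are not shown in this commit) and an import path for
LatentInformationGain that the commit view does not display; tensor shapes
follow the (n, d_x) / (n, d_y) layout implied by the torch.cat(..., dim=0)
calls above:

import torch
from botorch_community.models.np_regression import NeuralProcessModel

# Assumed module path for the class added in this commit.
from botorch_community.acquisition.latent_information_gain import (
    LatentInformationGain,
)

def select_next_batch(
    model: NeuralProcessModel,
    context_x: torch.Tensor,  # (n, d_x) observed inputs
    context_y: torch.Tensor,  # (n, d_y) observed targets
    candidates: list[torch.Tensor],  # each candidate batch is (q, d_x)
) -> int:
    """Return the index of the candidate batch with the highest LIG score."""
    lig = LatentInformationGain(model, num_samples=10)
    with torch.no_grad():  # scoring needs no gradients
        scores = [
            lig.acquisition(candidate_x, context_x, context_y)
            for candidate_x in candidates
        ]
    # A higher LIG score means the batch is expected to shift the latent
    # posterior more, i.e. it is more informative to evaluate next.
    return int(torch.stack(scores).argmax())

Since acquisition returns a scalar tensor, the scores can be compared
directly; evaluating under torch.no_grad() simply avoids building an unused
autograd graph.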
