 """

 from __future__ import annotations
-
-import warnings
-from typing import Optional
-
+from typing import Type, Any
 import torch
 from botorch.acquisition import AcquisitionFunction
 from botorch_community.models.np_regression import NeuralProcessModel
 from torch import Tensor
-
-import torch
-#reference: https://arxiv.org/abs/2106.02770
+# reference: https://arxiv.org/abs/2106.02770

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 class LatentInformationGain(AcquisitionFunction):
     def __init__(
-        self,
-        context_x: torch.Tensor,
-        context_y: torch.Tensor,
-        model: NeuralProcessModel,
+        self,
+        model: Type[Any] = NeuralProcessModel,
         num_samples: int = 10,
         min_std: float = 0.01,
-        scaler: float = 0.5
+        scaler: float = 0.5,
     ) -> None:
         """
-        Latent Information Gain (LIG) Acquisition Function, designed for the
-        NeuralProcessModel. This is a subclass of AcquisitionFunction.
+        Latent Information Gain (LIG) Acquisition Function.
+        Uses the model's built-in posterior function to generalize KL computation.

         Args:
-            model: Trained NeuralProcessModel.
-            context_x: Context input points, as a Tensor.
-            context_y: Context target points, as a Tensor.
+            model: The model class to be used, defaults to NeuralProcessModel.
             num_samples (int): Number of samples for calculation, defaults to 10.
-            min_std: Float representing the minimum possible standardized std, defaults to 0.1.
-            scaler: Float scaling the std, defaults to 0.9.
+            min_std: Float representing the minimum possible standardized std,
+                defaults to 0.01.
+            scaler: Float scaling the std, defaults to 0.5.
         """
         super().__init__(model=model)
-        self.model = model.to(device)
+        self.model = model
         self.num_samples = num_samples
         self.min_std = min_std
         self.scaler = scaler
-        self.context_x = context_x.to(device)
-        self.context_y = context_y.to(device)

     def forward(self, candidate_x: Tensor) -> Tensor:
         """
-        Conduct the Latent Information Gain acquisition function for the inputs.
+        Conduct the Latent Information Gain acquisition function using the model's
+        posterior.

         Args:
-            candidate_x: Candidate input points, as a Tensor. Ideally in the shape (N, q, D), and assumes N = 1 if the given dimensions are 2D.
+            candidate_x: Candidate input points, as a Tensor. Ideally in the shape
+                (N, q, D).

         Returns:
             torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
         """
         candidate_x = candidate_x.to(device)
         if candidate_x.dim() == 2:
-            candidate_x = candidate_x.unsqueeze(0)
+            candidate_x = candidate_x.unsqueeze(0)  # Ensure (N, q, D) format
         N, q, D = candidate_x.shape
-        # Encoding and Scaling the context data
-        z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
+
         kl = torch.zeros(N, q, device=device)
-        for _ in range(self.num_samples):
-            # Taking Samples/Predictions
-            samples = self.model.sample_z(z_mu_context, z_logvar_context)
-            y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
-            # Combining the data
-            combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device)
-            combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
-            # Computing posterior variables
-            z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
-            std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
-            std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior)
-            p = torch.distributions.Normal(z_mu_posterior, std_posterior)
-            q = torch.distributions.Normal(z_mu_context, std_prior)
-            kl_divergence = torch.distributions.kl_divergence(p, q).sum()
-            kl += kl_divergence
+
+        if isinstance(self.model, NeuralProcessModel):
+            # Use the data stored on the model as the context set; this assumes the
+            # NeuralProcessModel exposes train_X / train_Y, as the generic branch
+            # below already does (the constructor no longer receives context_x /
+            # context_y).
+            context_x = self.model.train_X.to(device)
+            context_y = self.model.train_Y.to(device)
+            z_mu_context, z_logvar_context = self.model.data_to_z_params(
+                context_x, context_y
+            )
+            for _ in range(self.num_samples):
+                # Taking Samples/Predictions
+                samples = self.model.sample_z(z_mu_context, z_logvar_context)
+                y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
+                # Combining the data
+                combined_x = torch.cat(
+                    [context_x, candidate_x.view(-1, D)], dim=0
+                ).to(device)
+                combined_y = torch.cat([context_y, y_pred], dim=0).to(device)
+                # Computing posterior variables
+                z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
+                    combined_x, combined_y
+                )
+                std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
+                std_posterior = self.min_std + self.scaler * torch.sigmoid(
+                    z_logvar_posterior
+                )
+                p = torch.distributions.Normal(z_mu_posterior, std_posterior)
+                q = torch.distributions.Normal(z_mu_context, std_prior)
+                kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1)
+                kl += kl_divergence
+        else:
+            for _ in range(self.num_samples):
+                posterior_prior = self.model.posterior(self.model.train_X)
+                posterior_candidate = self.model.posterior(candidate_x.view(-1, D))
+
+                kl_divergence = torch.distributions.kl_divergence(
+                    posterior_candidate.mvn, posterior_prior.mvn
+                ).sum(dim=-1)
+                kl += kl_divergence
+
         return kl / self.num_samples
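
For orientation, here is a minimal usage sketch of the class defined above as it would be used after this change. It is an illustration rather than part of the commit: np_model stands in for an already-fitted NeuralProcessModel instance (its construction and training are not shown here), and the candidate tensor follows the (N, q, D) layout described in the forward docstring.

import torch

# np_model is assumed to be a fitted NeuralProcessModel instance.
lig = LatentInformationGain(model=np_model, num_samples=10, min_std=0.01, scaler=0.5)

# Five candidate batches, each holding q=1 point in D=3 dimensions -> shape (N, q, D).
candidate_x = torch.rand(5, 1, 3)
lig_scores = lig(candidate_x)  # expected shape (N, q), per the Returns section above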