 from torch import Tensor
 # reference: https://arxiv.org/abs/2106.02770

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-

 class LatentInformationGain(AcquisitionFunction):
     def __init__(
@@ -56,58 +54,58 @@ def __init__(
         self.scaler = scaler

     def forward(self, candidate_x: Tensor) -> Tensor:
-        """
-        Conduct the Latent Information Gain acquisition function using the model's
-        posterior.
-
-        Args:
-            candidate_x: Candidate input points, as a Tensor. Ideally in the shape
-                (N, q, D).
-
-        Returns:
-            torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
-        """
+        device = candidate_x.device
         candidate_x = candidate_x.to(device)
-        if candidate_x.dim() == 2:
-            candidate_x = candidate_x.unsqueeze(0)  # Ensure (N, q, D) format
         N, q, D = candidate_x.shape
-
-        kl = torch.zeros(N, q, device=device)
-
+        kl = torch.zeros(N, device=device)
         if isinstance(self.model, NeuralProcessModel):
-            x_c, y_c, x_t, y_t = self.model.random_split_context_target(
-                self.model.train_X,
-                self.model.train_Y,
-                self.model.n_context
+            x_c, y_c, _, _ = self.model.random_split_context_target(
+                self.model.train_X, self.model.train_Y, self.model.n_context
             )
             z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c)
-            for _ in range(self.num_samples):
-                # Taking Samples/Predictions
-                samples = self.model.sample_z(z_mu_context, z_logvar_context)
-                y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
-                # Combining the data
-                combined_x = torch.cat([x_c, candidate_x.view(-1, D)], dim=0).to(device)
-                combined_y = torch.cat([y_c, y_pred], dim=0).to(device)
-                # Computing posterior variables
-                z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
-                    combined_x, combined_y
-                )
-                std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
-                std_posterior = self.min_std + self.scaler * torch.sigmoid(
-                    z_logvar_posterior
-                )
-                p = torch.distributions.Normal(z_mu_posterior, std_posterior)
-                q = torch.distributions.Normal(z_mu_context, std_prior)
-                kl_divergence = torch.distributions.kl_divergence(p, q).sum(dim=-1)
-                kl += kl_divergence
-        else:
-            for _ in range(self.num_samples):
-                posterior_prior = self.model.posterior(self.model.train_X)
-                posterior_candidate = self.model.posterior(candidate_x.view(-1, D))

-                kl_divergence = torch.distributions.kl_divergence(
-                    posterior_candidate.mvn, posterior_prior.mvn
-                ).sum(dim=-1)
-                kl += kl_divergence
+            for i in range(N):
+                x_i = candidate_x[i]
+                kl_i = 0.0

-        return kl / self.num_samples
+                for _ in range(self.num_samples):
+                    sample_z = self.model.sample_z(z_mu_context, z_logvar_context)
+                    if sample_z.dim() == 1:
+                        sample_z = sample_z.unsqueeze(0)
+
+                    y_pred = self.model.decoder(x_i, sample_z)
+
+                    combined_x = torch.cat([x_c, x_i], dim=0)
+                    combined_y = torch.cat([y_c, y_pred], dim=0)
+
+                    z_mu_post, z_logvar_post = self.model.data_to_z_params(
+                        combined_x, combined_y
+                    )
+
+                    std_prior = self.min_std + self.scaler * torch.sigmoid(
+                        z_logvar_context
+                    )
+                    std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post)
+
+                    p = torch.distributions.Normal(z_mu_post, std_post)
+                    q = torch.distributions.Normal(z_mu_context, std_prior)
+                    kl_sample = torch.distributions.kl_divergence(p, q).sum()
+
+                    kl_i += kl_sample
+
+                kl[i] = kl_i / self.num_samples
+
+        else:
+            for i in range(N):
+                x_i = candidate_x[i]
+                kl_i = 0.0
+                for _ in range(self.num_samples):
+                    posterior_prior = self.model.posterior(self.model.train_X)
+                    posterior_candidate = self.model.posterior(x_i)
+
+                    kl_i += torch.distributions.kl_divergence(
+                        posterior_candidate.mvn, posterior_prior.mvn
+                    ).sum()
+
+                kl[i] = kl_i / self.num_samples
+        return kl
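For readers skimming the diff: the reworked NeuralProcessModel branch now produces one averaged score per candidate batch rather than an (N, q) matrix. For each of the N candidates it draws num_samples latent samples, re-encodes the context set together with the decoded candidate predictions, and averages the summed KL divergence between the resulting latent posterior and the context prior. The snippet below is a minimal, self-contained sketch of that estimator only; the model calls (data_to_z_params, decoder, sample_z) are replaced with random stand-in tensors, and the constants N, z_dim, num_samples, min_std, and scaler are illustrative values, not the repository's defaults.

```python
# Minimal, self-contained sketch (not the repository's API) of the per-candidate
# Monte Carlo KL estimate computed in the NeuralProcessModel branch above.
import torch

N, z_dim, num_samples = 4, 8, 16   # illustrative sizes
min_std, scaler = 0.01, 0.5        # illustrative hyperparameters

# Stand-ins for the context encoding returned by data_to_z_params(x_c, y_c).
z_mu_context = torch.zeros(z_dim)
z_logvar_context = torch.zeros(z_dim)

kl = torch.zeros(N)
for i in range(N):
    kl_i = 0.0
    for _ in range(num_samples):
        # Stand-ins for the posterior parameters obtained after appending the
        # decoded candidate batch to the context set.
        z_mu_post = 0.1 * torch.randn(z_dim)
        z_logvar_post = 0.1 * torch.randn(z_dim)

        std_prior = min_std + scaler * torch.sigmoid(z_logvar_context)
        std_post = min_std + scaler * torch.sigmoid(z_logvar_post)

        p = torch.distributions.Normal(z_mu_post, std_post)
        q_dist = torch.distributions.Normal(z_mu_context, std_prior)
        kl_i += torch.distributions.kl_divergence(p, q_dist).sum()

    kl[i] = kl_i / num_samples  # one averaged score per candidate batch

print(kl.shape)  # torch.Size([4]) -- a length-N vector, not an (N, q) matrix
```

Because the division by num_samples now happens per candidate inside the loop, forward() can return kl directly as a length-N tensor of averaged KL scores, one per candidate batch, instead of the earlier (N, q)-shaped accumulator divided at the end.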