@@ -54,10 +54,35 @@ def __init__(
         self.scaler = scaler
 
     def forward(self, candidate_x: Tensor) -> Tensor:
+        """
+        Evaluate the Latent Information Gain acquisition function on the inputs.
+
+        Args:
+            candidate_x: Candidate input points, as a Tensor of shape
+                (N, q, D).
+
+        Returns:
+            torch.Tensor: The LIG scores (averaged KL divergences), of shape (N,).
+        """
         device = candidate_x.device
         candidate_x = candidate_x.to(device)
         N, q, D = candidate_x.shape
-        kl = torch.zeros(N, device=device)
+        kl = torch.zeros(N, device=device, dtype=torch.float32)
+        def normal_dist(mu, logvar, min_std, scaler):
+            r"""Helper function for creating the normal distributions.
+
+            Args:
+                mu: Tensor representing the Gaussian distribution mean.
+                logvar: Tensor representing the log variance of the
+                    Gaussian distribution.
+                min_std: Float representing the minimum standard deviation.
+                scaler: Float scaling the std.
+
+            Returns:
+                torch.distributions.Normal: The normal distribution.
+            """
+            std = min_std + scaler * torch.sigmoid(logvar)
+            return torch.distributions.Normal(mu, std)
         if isinstance(self.model, NeuralProcessModel):
             x_c, y_c, _, _ = self.model.random_split_context_target(
                 self.model.train_X, self.model.train_Y, self.model.n_context
@@ -82,15 +107,11 @@ def forward(self, candidate_x: Tensor) -> Tensor:
                         combined_x, combined_y
                     )
 
-                    std_prior = self.min_std + self.scaler * torch.sigmoid(
-                        z_logvar_context
+                    p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler)
+                    q = normal_dist(
+                        z_mu_context, z_logvar_context, self.min_std, self.scaler
                     )
-                    std_post = self.min_std + self.scaler * torch.sigmoid(z_logvar_post)
-
-                    p = torch.distributions.Normal(z_mu_post, std_post)
-                    q = torch.distributions.Normal(z_mu_context, std_prior)
 
                     kl_sample = torch.distributions.kl_divergence(p, q).sum()
-
                     kl_i += kl_sample
 
                 kl[i] = kl_i / self.num_samples
@@ -108,4 +129,4 @@ def forward(self, candidate_x: Tensor) -> Tensor:
                     ).sum()
 
                 kl[i] = kl_i / self.num_samples
-        return kl
+        return kl
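
For reference, a minimal standalone sketch of the refactored latent-KL step that the new normal_dist helper enables; the tensor shapes and the min_std/scaler values below are illustrative, not taken from the PR:

import torch

def normal_dist(mu, logvar, min_std, scaler):
    # Map a log-variance tensor to a positive std via sigmoid, as the helper above does.
    std = min_std + scaler * torch.sigmoid(logvar)
    return torch.distributions.Normal(mu, std)

# Hypothetical posterior and context latent parameters.
z_mu_post, z_logvar_post = torch.randn(8), torch.randn(8)
z_mu_context, z_logvar_context = torch.randn(8), torch.randn(8)

p = normal_dist(z_mu_post, z_logvar_post, min_std=0.01, scaler=0.5)
q = normal_dist(z_mu_context, z_logvar_context, min_std=0.01, scaler=0.5)

# Summed KL(p || q) over latent dimensions, mirroring kl_sample in the diff.
kl_sample = torch.distributions.kl_divergence(p, q).sum()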