@@ -62,43 +62,36 @@ def __init__(
         self.context_x = context_x.to(device)
         self.context_y = context_y.to(device)
 
-    def forward(self, candidate_x):
+    def forward(self, candidate_x: Tensor) -> Tensor:
         """
         Conduct the Latent Information Gain acquisition function for the inputs.
 
         Args:
-            candidate_x: Candidate input points, as a Tensor.
+            candidate_x: Candidate input points, as a Tensor. Expected shape is (N, q, D); N = 1 is assumed if a 2D tensor is given.
 
         Returns:
-            torch.Tensor: The LIG score of computed KLDs.
+            torch.Tensor: The LIG scores of computed KLDs, in the shape (N, q).
         """
-
         candidate_x = candidate_x.to(device)
-
+        if candidate_x.dim() == 2:
+            candidate_x = candidate_x.unsqueeze(0)
+        N, q, D = candidate_x.shape
         # Encoding and Scaling the context data
         z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
-        kl = 0.0
+        kl = torch.zeros(N, q, device=device)
         for _ in range(self.num_samples):
-            # Taking reparameterized samples
+            # Taking samples/predictions
             samples = self.model.sample_z(z_mu_context, z_logvar_context)
-
-            # Using the decoder to take predicted values
-            y_pred = self.model.decoder(candidate_x, samples)
-
-            # Combining context and candidate data
-            combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
+            y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
+            # Combining the context and candidate data
+            combined_x = torch.cat([self.context_x, candidate_x.view(-1, D)], dim=0).to(device)
             combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
-
             # Computing posterior variables
             z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
             std_prior = self.min_std + self.scaler * torch.sigmoid(z_logvar_context)
             std_posterior = self.min_std + self.scaler * torch.sigmoid(z_logvar_posterior)
-
             p = torch.distributions.Normal(z_mu_posterior, std_posterior)
             q = torch.distributions.Normal(z_mu_context, std_prior)
-
             kl_divergence = torch.distributions.kl_divergence(p, q).sum()
             kl += kl_divergence
-
-        # Average KLD
         return kl / self.num_samples
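
For reference, below is a minimal, self-contained sketch of the scoring step this `forward` implements: the LIG score is the KL divergence between the latent posterior (re-encoded from the context plus the decoded candidate predictions) and the context-only prior, averaged over `num_samples` Monte Carlo draws. The `z_dim`, `min_std`, `scaler`, and `num_samples` values are toy stand-ins, and the small random perturbation is only a placeholder for re-running `data_to_z_params` on the combined data; it is not the repo's API.

```python
import torch

torch.manual_seed(0)

z_dim = 8                    # latent dimension (toy value)
min_std, scaler = 0.01, 0.5  # stand-ins for self.min_std / self.scaler
num_samples = 32             # stand-in for self.num_samples

# Stand-ins for model.data_to_z_params(context_x, context_y).
z_mu_context = torch.randn(z_dim)
z_logvar_context = torch.randn(z_dim)

kl = 0.0
for _ in range(num_samples):
    # Placeholder for re-encoding [context + decoded candidates] each draw.
    z_mu_posterior = z_mu_context + 0.1 * torch.randn(z_dim)
    z_logvar_posterior = z_logvar_context + 0.1 * torch.randn(z_dim)

    # Same std parameterization as forward().
    std_prior = min_std + scaler * torch.sigmoid(z_logvar_context)
    std_posterior = min_std + scaler * torch.sigmoid(z_logvar_posterior)

    p = torch.distributions.Normal(z_mu_posterior, std_posterior)
    q = torch.distributions.Normal(z_mu_context, std_prior)

    # Elementwise KL for diagonal Gaussians, reduced to one scalar per draw.
    kl += torch.distributions.kl_divergence(p, q).sum()

print(kl / num_samples)  # Monte Carlo estimate of the LIG score
```

One thing worth noting about the accumulation: `kl_divergence(p, q).sum()` reduces over all latent dimensions to a single scalar per draw, so when it is added to the `(N, q)` `kl` buffer above it is broadcast, and every entry of the returned score tensor receives the same value.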