@@ -68,33 +68,23 @@ def forward(self, candidate_x: Tensor) -> Tensor:
         candidate_x = candidate_x.to(device)
         N, q, D = candidate_x.shape
         kl = torch.zeros(N, device=device, dtype=torch.float32)
-        def normal_dist(mu, logvar, min_std, scaler):
-            r"""Helper function for creating the normal distributions.
-
-            Args:
-                mu: Tensor representing the Gaussian distribution mean.
-                logvar: Tensor representing the log variance of the
-                    Gaussian distribution.
-                min_std: Float representing the minimum standardized std.
-                scaler: Float scaling the std.
-
-            Returns:
-                torch.distributions.Normal: The normal distribution.
-            """
-            std = min_std + scaler * torch.sigmoid(logvar)
-            return torch.distributions.Normal(mu, std)
+
         if isinstance(self.model, NeuralProcessModel):
             x_c, y_c, _, _ = self.model.random_split_context_target(
                 self.model.train_X, self.model.train_Y, self.model.n_context
             )
-            z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c)
+            self.model.z_mu_context, self.model.z_logvar_context = (
+                self.model.data_to_z_params(x_c, y_c)
+            )
 
             for i in range(N):
                 x_i = candidate_x[i]
                 kl_i = 0.0
 
                 for _ in range(self.num_samples):
-                    sample_z = self.model.sample_z(z_mu_context, z_logvar_context)
+                    sample_z = self.model.sample_z(
+                        self.model.z_mu_context, self.model.z_logvar_context
+                    )
                     if sample_z.dim() == 1:
                         sample_z = sample_z.unsqueeze(0)
 
@@ -103,15 +93,10 @@ def normal_dist(mu, logvar, min_std, scaler):
                     combined_x = torch.cat([x_c, x_i], dim=0)
                     combined_y = torch.cat([y_c, y_pred], dim=0)
 
-                    z_mu_post, z_logvar_post = self.model.data_to_z_params(
-                        combined_x, combined_y
-                    )
-
-                    p = normal_dist(z_mu_post, z_logvar_post, self.min_std, self.scaler)
-                    q = normal_dist(
-                        z_mu_context, z_logvar_context, self.min_std, self.scaler
+                    self.model.z_mu_all, self.model.z_logvar_all = (
+                        self.model.data_to_z_params(combined_x, combined_y)
                     )
-                    kl_sample = torch.distributions.kl_divergence(p, q).sum()
+                    kl_sample = self.model.KLD_gaussian(self.min_std, self.scaler)
                     kl_i += kl_sample
 
                 kl[i] = kl_i / self.num_samples
@@ -129,4 +114,4 @@ def normal_dist(mu, logvar, min_std, scaler):
                 ).sum()
 
                 kl[i] = kl_i / self.num_samples
-        return kl
+        return kl
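For context on the KL change: the deleted inline code built two `torch.distributions.Normal` objects and summed the `kl_divergence` between them, while the new code delegates that to `self.model.KLD_gaussian(self.min_std, self.scaler)`, operating on the `z_mu_all`/`z_logvar_all` and `z_mu_context`/`z_logvar_context` attributes set earlier in `forward`. Below is a minimal sketch of the computation `KLD_gaussian` is presumably performing, reconstructed from the removed `normal_dist` helper and `kl_divergence` call; the free-function name is hypothetical, and the actual method body on `NeuralProcessModel` is an assumption.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kld_gaussian_sketch(
    z_mu_all: torch.Tensor,
    z_logvar_all: torch.Tensor,
    z_mu_context: torch.Tensor,
    z_logvar_context: torch.Tensor,
    min_std: float,
    scaler: float,
) -> torch.Tensor:
    # Same std parameterization as the deleted normal_dist helper:
    # std = min_std + scaler * sigmoid(logvar).
    p = Normal(z_mu_all, min_std + scaler * torch.sigmoid(z_logvar_all))
    q = Normal(z_mu_context, min_std + scaler * torch.sigmoid(z_logvar_context))
    # KL(latent posterior over context + candidate || context-only latent),
    # summed over latent dimensions, as the removed inline code computed it.
    return kl_divergence(p, q).sum()
```

Assuming the model-side `KLD_gaussian` matches this, the refactor is behavior-preserving: it only moves the latent-distribution construction out of `forward` and into shared model state.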