@@ -47,7 +47,7 @@ def __init__(
                 defaults to 0.01.
             scaler: Float scaling the std, defaults to 0.5.
         """
-        super().__init__()
+        super().__init__(model)
         self.model = model
         self.num_samples = num_samples
         self.min_std = min_std
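For context, the constructor change above aligns the call with BoTorch's `AcquisitionFunction.__init__`, which accepts the model and stores it on the acquisition module. Below is a minimal sketch of that pattern, assuming the class subclasses `AcquisitionFunction`; the class name and the `num_samples` default are illustrative assumptions, not taken from this PR.

from botorch.acquisition import AcquisitionFunction
from botorch.models.model import Model


class NeuralProcessInformationGain(AcquisitionFunction):
    # Hypothetical class name; the real class name is not visible in this hunk.
    def __init__(
        self,
        model: Model,
        num_samples: int = 10,  # assumed default; not shown in the diff
        min_std: float = 0.01,
        scaler: float = 0.5,
    ) -> None:
        # AcquisitionFunction.__init__ registers the model on the module, so the
        # explicit `self.model = model` kept in the diff is a harmless re-assignment.
        super().__init__(model)
        self.num_samples = num_samples
        self.min_std = min_std
        self.scaler = scaler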
@@ -72,19 +72,28 @@ def forward(self, candidate_x: Tensor) -> Tensor:

         kl = torch.zeros(N, q, device=device)

-        if self.model is NeuralProcessModel:
-            z_mu_context, z_logvar_context = self.model.data_to_z_params(
-                self.context_x, self.context_y
+        if isinstance(self.model, NeuralProcessModel):
+            x_c, y_c, x_t, y_t = self.model.random_split_context_target(
+                self.model.train_X[:, 0], self.model.train_Y
             )
+            print(x_c.shape)
+            print(y_c.shape)
+            print(self.model.train_X)
+            print(self.model.train_X[:, 0])
+            print(self.model.train_Y)
+            print(self.model.train_Y[:, 0])
+            z_mu_context, z_logvar_context = self.model.data_to_z_params(x_c, y_c, xy_dim=-1)
+            print(z_mu_context)
+            print(z_logvar_context)
             for _ in range(self.num_samples):
                 # Taking Samples/Predictions
                 samples = self.model.sample_z(z_mu_context, z_logvar_context)
                 y_pred = self.model.decoder(candidate_x.view(-1, D), samples)
                 # Combining the data
                 combined_x = torch.cat(
-                    [self.context_x, candidate_x.view(-1, D)], dim=0
+                    [x_c, candidate_x.view(-1, D)], dim=0
                 ).to(device)
-                combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)
+                combined_y = torch.cat([y_c, y_pred], dim=0).to(device)
                 # Computing posterior variables
                 z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(
                     combined_x, combined_y
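The hunk ends before the KL accumulation, but the structure above (a context `(z_mu_context, z_logvar_context)` pair, `num_samples` Monte Carlo draws, and a posterior `(z_mu_posterior, z_logvar_posterior)` pair per draw) points at an averaged KL divergence between the two latent Gaussians as the acquisition value. The helper below is a sketch of that per-sample term under the usual diagonal-Gaussian assumption; the function name `gaussian_kl` and the averaging step are assumptions, not code from this diff.

import torch
from torch import Tensor


def gaussian_kl(mu_q: Tensor, logvar_q: Tensor, mu_p: Tensor, logvar_p: Tensor) -> Tensor:
    """KL(q || p) between diagonal Gaussians, summed over the latent dimension."""
    var_q = logvar_q.exp()
    var_p = logvar_p.exp()
    per_dim = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return per_dim.sum(dim=-1)

Inside the sampling loop, `kl` would then typically be incremented with `gaussian_kl(z_mu_posterior, z_logvar_posterior, z_mu_context, z_logvar_context)` and normalized by `self.num_samples` after the loop finishes.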