18
18
"""
19
19
20
20
from __future__ import annotations
21
- from typing import Type , Any
21
+
22
+ from typing import Any , Type
23
+
22
24
import torch
23
25
from botorch .acquisition import AcquisitionFunction
24
26
from botorch_community .models .np_regression import NeuralProcessModel
@@ -42,7 +44,7 @@ def __init__(
42
44
43
45
Args:
44
46
model: The model class to be used, defaults to NeuralProcessModel.
45
- num_samples (int): Number of samples for calculation, defaults to 10.
47
+ num_samples: Int showing the # of samples for calculation, defaults to 10.
46
48
min_std: Float representing the minimum possible standardized std,
47
49
defaults to 0.01.
48
50
scaler: Float scaling the std, defaults to 0.5.
@@ -74,26 +76,18 @@ def forward(self, candidate_x: Tensor) -> Tensor:
74
76
75
77
if isinstance (self .model , NeuralProcessModel ):
76
78
x_c , y_c , x_t , y_t = self .model .random_split_context_target (
77
- self .model .train_X [:, 0 ], self .model .train_Y
79
+ self .model .train_X ,
80
+ self .model .train_Y ,
81
+ self .model .n_context
78
82
)
79
- print (x_c .shape )
80
- print (y_c .shape )
81
- print (self .model .train_X )
82
- print (self .model .train_X [:, 0 ])
83
- print (self .model .train_Y )
84
- print (self .model .train_Y [:, 0 ])
85
- z_mu_context , z_logvar_context = self .model .data_to_z_params (x_c , y_c , xy_dim = - 1 )
86
- print (z_mu_context )
87
- print (z_logvar_context )
83
+ z_mu_context , z_logvar_context = self .model .data_to_z_params (x_c , y_c )
88
84
for _ in range (self .num_samples ):
89
85
# Taking Samples/Predictions
90
86
samples = self .model .sample_z (z_mu_context , z_logvar_context )
91
87
y_pred = self .model .decoder (candidate_x .view (- 1 , D ), samples )
92
88
# Combining the data
93
- combined_x = torch .cat (
94
- [x_c , candidate_x .view (- 1 , D )], dim = 0
95
- ).to (device )
96
- combined_y = torch .cat ([self .y_c , y_pred ], dim = 0 ).to (device )
89
+ combined_x = torch .cat ([x_c , candidate_x .view (- 1 , D )], dim = 0 ).to (device )
90
+ combined_y = torch .cat ([y_c , y_pred ], dim = 0 ).to (device )
97
91
# Computing posterior variables
98
92
z_mu_posterior , z_logvar_posterior = self .model .data_to_z_params (
99
93
combined_x , combined_y
0 commit comments