from typing import Optional

import torch
- from botorch import settings
+ from botorch.acquisition import AcquisitionFunction
from botorch_community.models.np_regression import NeuralProcessModel
from torch import Tensor

# reference: https://arxiv.org/abs/2106.02770


- class LatentInformationGain:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class LatentInformationGain(AcquisitionFunction):
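    # Subclassing AcquisitionFunction (an nn.Module) means the acquisition
    # value is computed in forward(), so instances can be called directly
    # on candidate points.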
    def __init__(
        self,
+       context_x: torch.Tensor,
+       context_y: torch.Tensor,
        model: NeuralProcessModel,
        num_samples: int = 10,
-       min_std: float = 0.1,
-       scaler: float = 0.9
+       min_std: float = 0.01,
+       scaler: float = 0.5
    ) -> None:
        """
        Latent Information Gain (LIG) Acquisition Function, designed for the
-       NeuralProcessModel.
+       NeuralProcessModel. This is a subclass of AcquisitionFunction.

        Args:
            model: Trained NeuralProcessModel.
+           context_x: Context input points, as a Tensor.
+           context_y: Context target points, as a Tensor.
            num_samples (int): Number of samples for calculation, defaults to 10.
-           min_std: Float representing the minimum possible standardized std, defaults to 0.1.
-           scaler: Float scaling the std, defaults to 0.9.
+           min_std: Float representing the minimum possible standardized std, defaults to 0.01.
+           scaler: Float scaling the std, defaults to 0.5.
        """
-       self.model = model
+       super().__init__(model=model)
+       self.model = model.to(device)
        self.num_samples = num_samples
        self.min_std = min_std
        self.scaler = scaler
+       self.context_x = context_x.to(device)
+       self.context_y = context_y.to(device)

-   def acquisition(self, candidate_x, context_x, context_y):
+   def forward(self, candidate_x):
        """
        Compute the Latent Information Gain acquisition value for the candidate inputs.

        Args:
            candidate_x: Candidate input points, as a Tensor.
-           context_x: Context input points, as a Tensor.
-           context_y: Context target points, as a Tensor.

        Returns:
            torch.Tensor: The LIG score accumulated from the computed KL divergences.
        """
+       candidate_x = candidate_x.to(device)
+
        # Encoding and Scaling the context data
-       z_mu_context, z_logvar_context = self.model.data_to_z_params(context_x, context_y)
+       z_mu_context, z_logvar_context = self.model.data_to_z_params(self.context_x, self.context_y)
        kl = 0.0
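        # Monte Carlo estimate of the information gain: accumulate the KL
        # divergence between the latent distribution conditioned on context +
        # candidate data and the context-only latent distribution, over
        # num_samples reparameterized draws.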
        for _ in range(self.num_samples):
            # Taking reparameterized samples
@@ -77,8 +86,8 @@ def acquisition(self, candidate_x, context_x, context_y):
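            # Decode the latent samples into predictions at the candidate points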
            y_pred = self.model.decoder(candidate_x, samples)

            # Combining context and candidate data
-           combined_x = torch.cat([context_x, candidate_x], dim=0)
-           combined_y = torch.cat([context_y, y_pred], dim=0)
+           combined_x = torch.cat([self.context_x, candidate_x], dim=0).to(device)
+           combined_y = torch.cat([self.context_y, y_pred], dim=0).to(device)

            # Computing posterior variables
            z_mu_posterior, z_logvar_posterior = self.model.data_to_z_params(combined_x, combined_y)
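A minimal usage sketch of the new interface (not part of the diff; the names trained_model, train_x, train_y, and candidates are illustrative): the context set is fixed at construction time, and forward() receives only the candidate points.

    lig = LatentInformationGain(
        context_x=train_x,
        context_y=train_y,
        model=trained_model,
        num_samples=10,
    )
    score = lig(candidates)  # calls forward() via nn.Module.__call__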