1
1
import unittest
2
2
import torch
3
- from torch import nn
4
- from torch .distributions import Normal
5
3
from botorch_community .acquisition .latent_information_gain import LatentInformationGain
6
4
from botorch_community .models .np_regression import NeuralProcessModel
7
5
@@ -11,6 +9,8 @@ def setUp(self):
11
9
self .y_dim = 1
12
10
self .r_dim = 8
13
11
self .z_dim = 3
12
+ self .context_x = torch .rand (10 , self .x_dim )
13
+ self .context_y = torch .rand (10 , self .y_dim )
14
14
self .r_hidden_dims = [16 , 16 ]
15
15
self .z_hidden_dims = [32 , 32 ]
16
16
self .decoder_hidden_dims = [16 , 16 ]
@@ -25,66 +25,38 @@ def setUp(self):
25
25
z_dim = self .z_dim ,
26
26
)
27
27
self .acquisition_function = LatentInformationGain (
28
+ context_x = self .context_x ,
29
+ context_y = self .context_y ,
28
30
model = self .model ,
29
31
num_samples = self .num_samples ,
30
32
)
31
- self .context_x = torch .rand (10 , self .x_dim )
32
- self .context_y = torch .rand (10 , self .y_dim )
33
33
self .candidate_x = torch .rand (5 , self .x_dim )
34
34
35
35
def test_initialization (self ):
36
36
self .assertEqual (self .acquisition_function .num_samples , self .num_samples )
37
37
self .assertEqual (self .acquisition_function .model , self .model )
38
38
39
39
def test_acquisition_shape (self ):
40
- lig_score = self .acquisition_function .acquisition (
41
- candidate_x = self .candidate_x ,
42
- context_x = self .context_x ,
43
- context_y = self .context_y ,
40
+ lig_score = self .acquisition_function .forward (
41
+ candidate_x = self .candidate_x
44
42
)
45
43
self .assertTrue (torch .is_tensor (lig_score ))
46
44
self .assertEqual (lig_score .shape , ())
47
45
48
46
def test_acquisition_kl (self ):
49
- lig_score = self .acquisition_function .acquisition (
50
- candidate_x = self .candidate_x ,
51
- context_x = self .context_x ,
52
- context_y = self .context_y ,
47
+ lig_score = self .acquisition_function .forward (
48
+ candidate_x = self .candidate_x
53
49
)
54
50
self .assertGreaterEqual (lig_score .item (), 0 )
55
51
56
52
def test_acquisition_samples (self ):
57
- lig_1 = self .acquisition_function .acquisition (
58
- candidate_x = self .candidate_x ,
59
- context_x = self .context_x ,
60
- context_y = self .context_y ,
61
- )
53
+ lig_1 = self .acquisition_function .forward (candidate_x = self .candidate_x )
62
54
63
55
self .acquisition_function .num_samples = 20
64
- lig_2 = self .acquisition_function .acquisition (
65
- candidate_x = self .candidate_x ,
66
- context_x = self .context_x ,
67
- context_y = self .context_y ,
68
- )
56
+ lig_2 = self .acquisition_function .forward (candidate_x = self .candidate_x )
69
57
self .assertTrue (lig_2 .item () < lig_1 .item ())
70
58
self .assertTrue (abs (lig_2 .item () - lig_1 .item ()) < 0.2 )
71
59
72
- def test_acquisition_invalid_inputs (self ):
73
- invalid_context_x = torch .rand (10 , self .x_dim + 5 )
74
- with self .assertRaises (Exception ):
75
- self .acquisition_function .acquisition (
76
- candidate_x = self .candidate_x ,
77
- context_x = invalid_context_x ,
78
- context_y = self .context_y ,
79
- )
80
-
81
- invalid_candidate_x = torch .rand (5 , self .x_dim + 5 )
82
- with self .assertRaises (Exception ):
83
- self .acquisition_function .acquisition (
84
- candidate_x = invalid_candidate_x ,
85
- context_x = self .context_x ,
86
- context_y = self .context_y ,
87
- )
88
60
89
61
90
62
if __name__ == "__main__" :
0 commit comments