Skip to content

Commit c55d7a9

Browse files
authored
Test Latent Information Gain
1 parent 5f0dba0 commit c55d7a9

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import unittest
2+
import torch
3+
from torch import nn
4+
from torch.distributions import Normal
5+
from botorch_community.acquisition.latent_information_gain import LatentInformationGain
6+
from botorch_community.models.np_regression import NeuralProcessModel
7+
8+
class TestLatentInformationGain(unittest.TestCase):
9+
def setUp(self):
10+
self.x_dim = 2
11+
self.y_dim = 1
12+
self.r_dim = 8
13+
self.z_dim = 3
14+
self.r_hidden_dims = [16, 16]
15+
self.z_hidden_dims = [32, 32]
16+
self.decoder_hidden_dims = [16, 16]
17+
self.num_samples = 10
18+
self.model = NeuralProcessModel(
19+
r_hidden_dims = self.r_hidden_dims,
20+
z_hidden_dims = self.z_hidden_dims,
21+
decoder_hidden_dims = self.decoder_hidden_dims,
22+
x_dim=self.x_dim,
23+
y_dim=self.y_dim,
24+
r_dim=self.r_dim,
25+
z_dim=self.z_dim,
26+
)
27+
self.acquisition_function = LatentInformationGain(
28+
model=self.model,
29+
num_samples=self.num_samples,
30+
)
31+
self.context_x = torch.rand(10, self.x_dim)
32+
self.context_y = torch.rand(10, self.y_dim)
33+
self.candidate_x = torch.rand(5, self.x_dim)
34+
35+
def test_initialization(self):
36+
self.assertEqual(self.acquisition_function.num_samples, self.num_samples)
37+
self.assertEqual(self.acquisition_function.model, self.model)
38+
39+
def test_acquisition_shape(self):
40+
lig_score = self.acquisition_function.acquisition(
41+
candidate_x=self.candidate_x,
42+
context_x=self.context_x,
43+
context_y=self.context_y,
44+
)
45+
self.assertTrue(torch.is_tensor(lig_score))
46+
self.assertEqual(lig_score.shape, ())
47+
48+
def test_acquisition_kl(self):
49+
lig_score = self.acquisition_function.acquisition(
50+
candidate_x=self.candidate_x,
51+
context_x=self.context_x,
52+
context_y=self.context_y,
53+
)
54+
self.assertGreaterEqual(lig_score.item(), 0)
55+
56+
def test_acquisition_samples(self):
57+
lig_1 = self.acquisition_function.acquisition(
58+
candidate_x=self.candidate_x,
59+
context_x=self.context_x,
60+
context_y=self.context_y,
61+
)
62+
63+
self.acquisition_function.num_samples = 20
64+
lig_2 = self.acquisition_function.acquisition(
65+
candidate_x=self.candidate_x,
66+
context_x=self.context_x,
67+
context_y=self.context_y,
68+
)
69+
self.assertTrue(lig_2.item() < lig_1.item())
70+
self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2)
71+
72+
def test_acquisition_invalid_inputs(self):
73+
invalid_context_x = torch.rand(10, self.x_dim + 5)
74+
with self.assertRaises(Exception):
75+
self.acquisition_function.acquisition(
76+
candidate_x=self.candidate_x,
77+
context_x=invalid_context_x,
78+
context_y=self.context_y,
79+
)
80+
81+
invalid_candidate_x = torch.rand(5, self.x_dim + 5)
82+
with self.assertRaises(Exception):
83+
self.acquisition_function.acquisition(
84+
candidate_x=invalid_candidate_x,
85+
context_x=self.context_x,
86+
context_y=self.context_y,
87+
)
88+
89+
90+
if __name__ == "__main__":
91+
unittest.main()

0 commit comments

Comments
 (0)