Skip to content

Commit e13f38c

Browse files
authored
Test LIG WIP
1 parent 32916b9 commit e13f38c

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed
Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,50 @@
11
import unittest
2+
23
import torch
34
from botorch_community.acquisition.latent_information_gain import LatentInformationGain
45
from botorch_community.models.np_regression import NeuralProcessModel
56

7+
68
class TestLatentInformationGain(unittest.TestCase):
79
def setUp(self):
810
self.x_dim = 2
911
self.y_dim = 1
1012
self.r_dim = 8
11-
self.z_dim = 3
12-
self.context_x = torch.rand(10, self.x_dim)
13-
self.context_y = torch.rand(10, self.y_dim)
13+
self.z_dim = 8
1414
self.r_hidden_dims = [16, 16]
1515
self.z_hidden_dims = [32, 32]
1616
self.decoder_hidden_dims = [16, 16]
17-
self.num_samples = 10
1817
self.model = NeuralProcessModel(
19-
r_hidden_dims = self.r_hidden_dims,
20-
z_hidden_dims = self.z_hidden_dims,
21-
decoder_hidden_dims = self.decoder_hidden_dims,
18+
torch.rand(10, self.x_dim),
19+
torch.rand(10, self.y_dim),
20+
r_hidden_dims=self.r_hidden_dims,
21+
z_hidden_dims=self.z_hidden_dims,
22+
decoder_hidden_dims=self.decoder_hidden_dims,
2223
x_dim=self.x_dim,
2324
y_dim=self.y_dim,
2425
r_dim=self.r_dim,
2526
z_dim=self.z_dim,
2627
)
2728
self.acquisition_function = LatentInformationGain(
28-
context_x=self.context_x,
29-
context_y=self.context_y,
3029
model=self.model,
31-
num_samples=self.num_samples,
3230
)
3331
self.candidate_x = torch.rand(5, self.x_dim)
3432

3533
def test_initialization(self):
36-
self.assertEqual(self.acquisition_function.num_samples, self.num_samples)
34+
self.assertEqual(self.acquisition_function.num_samples, 10)
3735
self.assertEqual(self.acquisition_function.model, self.model)
3836

3937
def test_acquisition_shape(self):
40-
lig_score = self.acquisition_function.forward(
41-
candidate_x=self.candidate_x
42-
)
38+
self.model(self.model.train_X, self.model.train_Y)
39+
lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x)
4340
self.assertTrue(torch.is_tensor(lig_score))
4441
self.assertEqual(lig_score.shape, (1, 5))
4542

4643
def test_acquisition_kl(self):
47-
lig_score = self.acquisition_function.forward(
48-
candidate_x=self.candidate_x
49-
)
44+
self.model(self.model.train_X, self.model.train_Y)
45+
lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x)
5046
self.assertGreaterEqual(lig_score.mean().item(), 0)
5147

48+
5249
if __name__ == "__main__":
5350
unittest.main()

0 commit comments

Comments
 (0)