Skip to content

Commit a811429

Browse files
authored
1/25 Updates
1 parent 280776d commit a811429

File tree

1 file changed

+10
-38
lines changed

1 file changed

+10
-38
lines changed

test_community/acquisition/test_latent_information_gain.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import unittest
22
import torch
3-
from torch import nn
4-
from torch.distributions import Normal
53
from botorch_community.acquisition.latent_information_gain import LatentInformationGain
64
from botorch_community.models.np_regression import NeuralProcessModel
75

@@ -11,6 +9,8 @@ def setUp(self):
119
self.y_dim = 1
1210
self.r_dim = 8
1311
self.z_dim = 3
12+
self.context_x = torch.rand(10, self.x_dim)
13+
self.context_y = torch.rand(10, self.y_dim)
1414
self.r_hidden_dims = [16, 16]
1515
self.z_hidden_dims = [32, 32]
1616
self.decoder_hidden_dims = [16, 16]
@@ -25,66 +25,38 @@ def setUp(self):
2525
z_dim=self.z_dim,
2626
)
2727
self.acquisition_function = LatentInformationGain(
28+
context_x=self.context_x,
29+
context_y=self.context_y,
2830
model=self.model,
2931
num_samples=self.num_samples,
3032
)
31-
self.context_x = torch.rand(10, self.x_dim)
32-
self.context_y = torch.rand(10, self.y_dim)
3333
self.candidate_x = torch.rand(5, self.x_dim)
3434

3535
def test_initialization(self):
3636
self.assertEqual(self.acquisition_function.num_samples, self.num_samples)
3737
self.assertEqual(self.acquisition_function.model, self.model)
3838

3939
def test_acquisition_shape(self):
40-
lig_score = self.acquisition_function.acquisition(
41-
candidate_x=self.candidate_x,
42-
context_x=self.context_x,
43-
context_y=self.context_y,
40+
lig_score = self.acquisition_function.forward(
41+
candidate_x=self.candidate_x
4442
)
4543
self.assertTrue(torch.is_tensor(lig_score))
4644
self.assertEqual(lig_score.shape, ())
4745

4846
def test_acquisition_kl(self):
49-
lig_score = self.acquisition_function.acquisition(
50-
candidate_x=self.candidate_x,
51-
context_x=self.context_x,
52-
context_y=self.context_y,
47+
lig_score = self.acquisition_function.forward(
48+
candidate_x=self.candidate_x
5349
)
5450
self.assertGreaterEqual(lig_score.item(), 0)
5551

5652
def test_acquisition_samples(self):
57-
lig_1 = self.acquisition_function.acquisition(
58-
candidate_x=self.candidate_x,
59-
context_x=self.context_x,
60-
context_y=self.context_y,
61-
)
53+
lig_1 = self.acquisition_function.forward(candidate_x=self.candidate_x)
6254

6355
self.acquisition_function.num_samples = 20
64-
lig_2 = self.acquisition_function.acquisition(
65-
candidate_x=self.candidate_x,
66-
context_x=self.context_x,
67-
context_y=self.context_y,
68-
)
56+
lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x)
6957
self.assertTrue(lig_2.item() < lig_1.item())
7058
self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2)
7159

72-
def test_acquisition_invalid_inputs(self):
73-
invalid_context_x = torch.rand(10, self.x_dim + 5)
74-
with self.assertRaises(Exception):
75-
self.acquisition_function.acquisition(
76-
candidate_x=self.candidate_x,
77-
context_x=invalid_context_x,
78-
context_y=self.context_y,
79-
)
80-
81-
invalid_candidate_x = torch.rand(5, self.x_dim + 5)
82-
with self.assertRaises(Exception):
83-
self.acquisition_function.acquisition(
84-
candidate_x=invalid_candidate_x,
85-
context_x=self.context_x,
86-
context_y=self.context_y,
87-
)
8860

8961

9062
if __name__ == "__main__":

0 commit comments

Comments
 (0)