Skip to content

Commit 2c9c958

Browse files
authored
April Tests
1 parent 9949ff9 commit 2c9c958

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

test_community/acquisition/test_latent_information_gain.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
import torch
4+
from botorch.optim.optimize import optimize_acqf
45
from botorch_community.acquisition.latent_information_gain import LatentInformationGain
56
from botorch_community.models.np_regression import NeuralProcessModel
67

@@ -32,16 +33,22 @@ def test_initialization(self):
3233
self.assertEqual(self.acquisition_function.num_samples, 10)
3334
self.assertEqual(self.acquisition_function.model, self.model)
3435

35-
def test_acquisition_shape(self):
36-
self.model(self.model.train_X, self.model.train_Y)
37-
lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x)
38-
self.assertTrue(torch.is_tensor(lig_score))
39-
self.assertEqual(lig_score.shape, (1, 5))
40-
41-
def test_acquisition_kl(self):
42-
self.model(self.model.train_X, self.model.train_Y)
43-
lig_score = self.acquisition_function.forward(candidate_x=self.candidate_x)
44-
self.assertGreaterEqual(lig_score.mean().item(), 0)
36+
def test_acqf(self):
37+
bounds = torch.tensor([[0.0] * self.x_dim, [1.0] * self.x_dim])
38+
q = 3
39+
raw_samples = 8
40+
num_restarts = 2
41+
42+
candidate = optimize_acqf(
43+
acq_function=self.acquisition_function,
44+
bounds=bounds,
45+
q=q,
46+
raw_samples=raw_samples,
47+
num_restarts=num_restarts,
48+
)
49+
self.assertTrue(isinstance(candidate, tuple))
50+
self.assertEqual(candidate[0].shape, (q, self.x_dim))
51+
self.assertTrue(torch.all(candidate[1] >= 0))
4552

4653

4754
if __name__ == "__main__":

0 commit comments

Comments
 (0)