Skip to content

Commit e7c964d

Browse files
authored
NPR Updated Tests
1 parent 4c037c5 commit e7c964d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

test_community/models/test_np_regression.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
23
import torch
34
from botorch.posteriors import GPyTorchPosterior
45
from botorch_community.models.np_regression import NeuralProcessModel
@@ -15,6 +16,7 @@ def initialize(self):
1516
self.y_dim = 1
1617
self.r_dim = 8
1718
self.z_dim = 8
19+
self.n_context = 20
1820
self.model = NeuralProcessModel(
1921
torch.rand(100, self.x_dim),
2022
torch.rand(100, self.y_dim),
@@ -25,6 +27,7 @@ def initialize(self):
2527
self.y_dim,
2628
self.r_dim,
2729
self.z_dim,
30+
self.n_context
2831
)
2932

3033
def test_r_encoder(self):
@@ -71,10 +74,7 @@ def test_KLD_gaussian(self):
7174

7275
def test_data_to_z_params(self):
7376
self.initialize()
74-
mu, logvar = self.model.data_to_z_params(
75-
self.model.train_X,
76-
self.model.train_Y
77-
)
77+
mu, logvar = self.model.data_to_z_params(self.model.train_X, self.model.train_Y)
7878
self.assertEqual(mu.shape, (self.z_dim,))
7979
self.assertEqual(logvar.shape, (self.z_dim,))
8080
self.assertTrue(torch.is_tensor(mu))
@@ -88,7 +88,7 @@ def test_forward(self):
8888
def test_random_split_context_target(self):
8989
self.initialize()
9090
x_c, y_c, x_t, y_t = self.model.random_split_context_target(
91-
self.model.train_X[:, 0], self.model.train_Y
91+
self.model.train_X[:, 0], self.model.train_Y, self.model.n_context
9292
)
9393
self.assertEqual(x_c.shape[0], 20)
9494
self.assertEqual(y_c.shape[0], 20)

0 commit comments

Comments
 (0)