Skip to content

Commit 5f0dba0

Browse files
authored
Test NP Regression
1 parent 7a8e4ab commit 5f0dba0

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import unittest
2+
import numpy as np
3+
import torch
4+
from torch import nn
5+
from torch.optim import Adam
6+
from botorch_community.models.np_regression import NeuralProcessModel
7+
from botorch.posteriors import GPyTorchPosterior
8+
from torch import Tensor
9+
10+
class TestNeuralProcessModel(unittest.TestCase):
11+
def initialize(self):
12+
self.r_hidden_dims = [16, 16]
13+
self.z_hidden_dims = [32, 32]
14+
self.decoder_hidden_dims = [16, 16]
15+
self.x_dim = 2
16+
self.y_dim = 1
17+
self.r_dim = 8
18+
self.z_dim = 8
19+
self.model = NeuralProcessModel(
20+
self.r_hidden_dims,
21+
self.z_hidden_dims,
22+
self.decoder_hidden_dims,
23+
self.x_dim,
24+
self.y_dim,
25+
self.r_dim,
26+
self.z_dim,
27+
)
28+
self.x_data = np.random.rand(100, self.x_dim)
29+
self.y_data = np.random.rand(100, self.y_dim)
30+
31+
def test_r_encoder(self):
32+
self.initialize()
33+
input = torch.rand(10, self.x_dim + self.y_dim)
34+
output = self.model.r_encoder(input)
35+
self.assertEqual(output.shape, (10, self.r_dim))
36+
self.assertTrue(torch.is_tensor(output))
37+
38+
def test_z_encoder(self):
39+
self.initialize()
40+
input = torch.rand(10, self.r_dim)
41+
mean, logvar = self.model.z_encoder(input)
42+
self.assertEqual(mean.shape, (10, self.z_dim))
43+
self.assertEqual(logvar.shape, (10, self.z_dim))
44+
self.assertTrue(torch.is_tensor(mean))
45+
self.assertTrue(torch.is_tensor(logvar))
46+
47+
def test_decoder(self):
48+
self.initialize()
49+
x_pred = torch.rand(10, self.x_dim)
50+
z = torch.rand(self.z_dim)
51+
output = self.model.decoder(x_pred, z)
52+
self.assertEqual(output.shape, (10, self.y_dim))
53+
self.assertTrue(torch.is_tensor(output))
54+
55+
def test_sample_z(self):
56+
self.initialize()
57+
mu = torch.rand(self.z_dim)
58+
logvar = torch.rand(self.z_dim)
59+
samples = self.model.sample_z(mu, logvar, n=5)
60+
self.assertEqual(samples.shape, (5, self.z_dim))
61+
self.assertTrue(torch.is_tensor(samples))
62+
63+
def test_KLD_gaussian(self):
64+
self.initialize()
65+
self.model.z_mu_all = torch.rand(self.z_dim)
66+
self.model.z_logvar_all = torch.rand(self.z_dim)
67+
self.model.z_mu_context = torch.rand(self.z_dim)
68+
self.model.z_logvar_context = torch.rand(self.z_dim)
69+
kld = self.model.KLD_gaussian()
70+
self.assertGreaterEqual(kld.item(), 0)
71+
self.assertTrue(torch.is_tensor(kld))
72+
73+
def test_data_to_z_params(self):
74+
self.initialize()
75+
x = torch.rand(10, self.x_dim)
76+
y = torch.rand(10, self.y_dim)
77+
mu, logvar = self.model.data_to_z_params(x, y)
78+
self.assertEqual(mu.shape, (self.z_dim,))
79+
self.assertEqual(logvar.shape, (self.z_dim,))
80+
self.assertTrue(torch.is_tensor(mu))
81+
self.assertTrue(torch.is_tensor(logvar))
82+
83+
def test_forward(self):
84+
self.initialize()
85+
x_t = torch.rand(5, self.x_dim)
86+
x_c = torch.rand(10, self.x_dim)
87+
y_c = torch.rand(10, self.y_dim)
88+
y_t = torch.rand(5, self.y_dim)
89+
output = self.model(x_t, x_c, y_c, y_t)
90+
self.assertEqual(output.shape, (5, self.y_dim))
91+
92+
def test_random_split_context_target(self):
93+
self.initialize()
94+
x_c, y_c, x_t, y_t = self.model.random_split_context_target(
95+
self.x_data[:, 0], self.y_data, 20, 0
96+
)
97+
self.assertEqual(x_c.shape[0], 20)
98+
self.assertEqual(y_c.shape[0], 20)
99+
self.assertEqual(x_t.shape[0], 80)
100+
self.assertEqual(y_t.shape[0], 80)
101+
102+
def test_posterior(self):
103+
self.initialize()
104+
x_t = torch.rand(5, self.x_dim)
105+
x_c = torch.rand(10, self.x_dim)
106+
y_c = torch.rand(10, self.y_dim)
107+
y_t = torch.rand(5, self.y_dim)
108+
output = self.model(x_t, x_c, y_c, y_t)
109+
posterior = self.model.posterior(x_t, 0.1, 0.01, observation_noise=True)
110+
self.assertIsInstance(posterior, GPyTorchPosterior)
111+
mvn = posterior.mvn
112+
self.assertEqual(mvn.covariance_matrix.size(), (5, 5, 5))
113+
114+
def test_load_state_dict(self):
115+
self.initialize()
116+
state_dict = {"r_encoder.mlp.model.0.bias": torch.rand(16)}
117+
self.model.load_state_dict(state_dict, strict = False)
118+
119+
def test_transform_inputs(self):
120+
self.initialize()
121+
X = torch.rand(5, 3)
122+
self.assertTrue(torch.equal(self.model.transform_inputs(X), X))
123+
124+
125+
if __name__ == "__main__":
126+
unittest.main()

0 commit comments

Comments
 (0)