File tree Expand file tree Collapse file tree 1 file changed +3
-9
lines changed Expand file tree Collapse file tree 1 file changed +3
-9
lines changed Original file line number Diff line number Diff line change 1
1
import unittest
2
2
import numpy as np
3
3
import torch
4
- from torch import nn
5
- from torch .optim import Adam
6
4
from botorch_community .models .np_regression import NeuralProcessModel
7
5
from botorch .posteriors import GPyTorchPosterior
8
- from torch import Tensor
6
+
7
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
9
8
10
9
class TestNeuralProcessModel (unittest .TestCase ):
11
10
def initialize (self ):
@@ -111,15 +110,10 @@ def test_posterior(self):
111
110
mvn = posterior .mvn
112
111
self .assertEqual (mvn .covariance_matrix .size (), (5 , 5 , 5 ))
113
112
114
- def test_load_state_dict (self ):
115
- self .initialize ()
116
- state_dict = {"r_encoder.mlp.model.0.bias" : torch .rand (16 )}
117
- self .model .load_state_dict (state_dict , strict = False )
118
-
119
113
def test_transform_inputs (self ):
120
114
self .initialize ()
121
115
X = torch .rand (5 , 3 )
122
- self .assertTrue (torch .equal (self .model .transform_inputs (X ), X ))
116
+ self .assertTrue (torch .equal (self .model .transform_inputs (X ), X . to ( device ) ))
123
117
124
118
125
119
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments