File tree Expand file tree Collapse file tree 1 file changed +2
-12
lines changed Expand file tree Collapse file tree 1 file changed +2
-12
lines changed Original file line number Diff line number Diff line change @@ -41,23 +41,13 @@ def test_acquisition_shape(self):
41
41
candidate_x = self .candidate_x
42
42
)
43
43
self .assertTrue (torch .is_tensor (lig_score ))
44
- self .assertEqual (lig_score .shape , ())
44
+ self .assertEqual (lig_score .shape , (1 , 5 ))
45
45
46
46
def test_acquisition_kl (self ):
47
47
lig_score = self .acquisition_function .forward (
48
48
candidate_x = self .candidate_x
49
49
)
50
- self .assertGreaterEqual (lig_score .item (), 0 )
51
-
52
- def test_acquisition_samples (self ):
53
- lig_1 = self .acquisition_function .forward (candidate_x = self .candidate_x )
54
-
55
- self .acquisition_function .num_samples = 20
56
- lig_2 = self .acquisition_function .forward (candidate_x = self .candidate_x )
57
- self .assertTrue (lig_2 .item () < lig_1 .item ())
58
- self .assertTrue (abs (lig_2 .item () - lig_1 .item ()) < 0.2 )
59
-
60
-
50
+ self .assertGreaterEqual (lig_score .mean ().item (), 0 )
61
51
62
52
if __name__ == "__main__" :
63
53
unittest .main ()
You can’t perform that action at this time.
0 commit comments