6
6
7
7
import itertools
8
8
from copy import deepcopy
9
+ from random import randint
9
10
10
11
import torch
11
12
from botorch .models .transforms .outcome import (
24
25
from botorch .posteriors import GPyTorchPosterior , TransformedPosterior
25
26
from botorch .utils .testing import BotorchTestCase
26
27
from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
28
+ from gpytorch .settings import min_variance
27
29
from linear_operator .operators import (
28
30
BlockDiagLinearOperator ,
29
31
DenseLinearOperator ,
@@ -368,9 +370,12 @@ def test_standardize_state_dict(self):
368
370
369
371
def test_stratified_standardize (self ):
370
372
n = 5
373
+ seed = randint (0 , 100 )
374
+ torch .manual_seed (seed )
371
375
for dtype , batch_shape in itertools .product (
372
376
(torch .float , torch .double ), (torch .Size ([]), torch .Size ([3 ]))
373
377
):
378
+ torch .manual_seed (seed )
374
379
X = torch .rand (* batch_shape , n , 2 , dtype = dtype , device = self .device )
375
380
X [..., - 1 ] = torch .tensor ([0 , 1 , 0 , 1 , 0 ], dtype = dtype , device = self .device )
376
381
Y = torch .randn (* batch_shape , n , 1 , dtype = dtype , device = self .device )
@@ -389,38 +394,29 @@ def test_stratified_standardize(self):
389
394
Y1 = Y [mask1 ].view (* batch_shape , - 1 , 1 )
390
395
Yvar1 = Yvar [mask1 ].view (* batch_shape , - 1 , 1 )
391
396
X1 = X [mask1 ].view (* batch_shape , - 1 , 1 )
392
- tf0 = Standardize (
393
- m = 1 ,
394
- batch_shape = batch_shape ,
395
- )
397
+ tf0 = Standardize (m = 1 , batch_shape = batch_shape )
396
398
tf_Y0 , tf_Yvar0 = tf0 (Y = Y0 , Yvar = Yvar0 , X = X0 )
397
- tf1 = Standardize (
398
- m = 1 ,
399
- batch_shape = batch_shape ,
400
- )
399
+ tf1 = Standardize (m = 1 , batch_shape = batch_shape )
401
400
tf_Y1 , tf_Yvar1 = tf1 (Y = Y1 , Yvar = Yvar1 , X = X1 )
402
401
# check that stratified means are expected
403
- self .assertTrue ( torch . allclose ( strata_tf .means [..., :1 , :], tf0 .means ) )
404
- self .assertTrue ( torch . allclose ( strata_tf .means [..., 1 :, :], tf1 .means ) )
405
- self .assertTrue ( torch . allclose ( strata_tf .stdvs [..., :1 , :], tf0 .stdvs ) )
406
- self .assertTrue ( torch . allclose ( strata_tf .stdvs [..., 1 :, :], tf1 .stdvs ) )
402
+ self .assertAllClose ( strata_tf .means [..., :1 , :], tf0 .means )
403
+ self .assertAllClose ( strata_tf .means [..., 1 :, :], tf1 .means )
404
+ self .assertAllClose ( strata_tf .stdvs [..., :1 , :], tf0 .stdvs )
405
+ self .assertAllClose ( strata_tf .stdvs [..., 1 :, :], tf1 .stdvs )
407
406
# check the transformed values
408
- self .assertTrue (
409
- torch .allclose (tf_Y0 , tf_Y [mask0 ].view (* batch_shape , - 1 , 1 ))
410
- )
411
- self .assertTrue (
412
- torch .allclose (tf_Y1 , tf_Y [mask1 ].view (* batch_shape , - 1 , 1 ))
413
- )
414
- self .assertTrue (
415
- torch .allclose (tf_Yvar0 , tf_Yvar [mask0 ].view (* batch_shape , - 1 , 1 ))
416
- )
417
- self .assertTrue (
418
- torch .allclose (tf_Yvar1 , tf_Yvar [mask1 ].view (* batch_shape , - 1 , 1 ))
419
- )
407
+ self .assertAllClose (tf_Y0 , tf_Y [mask0 ].view (* batch_shape , - 1 , 1 ))
408
+ self .assertAllClose (tf_Y1 , tf_Y [mask1 ].view (* batch_shape , - 1 , 1 ))
409
+ self .assertAllClose (tf_Yvar0 , tf_Yvar [mask0 ].view (* batch_shape , - 1 , 1 ))
410
+ self .assertAllClose (tf_Yvar1 , tf_Yvar [mask1 ].view (* batch_shape , - 1 , 1 ))
420
411
untf_Y , untf_Yvar = strata_tf .untransform (Y = tf_Y , Yvar = tf_Yvar , X = X )
421
412
# test untransform
422
- self .assertTrue (torch .allclose (Y , untf_Y ))
423
- self .assertTrue (torch .allclose (Yvar , untf_Yvar ))
413
+ if dtype == torch .float32 :
414
+ # defaults are 1e-5, 1e-8
415
+ tols = {"rtol" : 2e-5 , "atol" : 8e-8 }
416
+ else :
417
+ tols = {}
418
+ self .assertAllClose (Y , untf_Y , ** tols )
419
+ self .assertAllClose (Yvar , untf_Yvar )
424
420
425
421
# test untransform_posterior
426
422
for lazy in (True , False ):
@@ -434,14 +430,23 @@ def test_stratified_standardize(self):
434
430
)
435
431
p_utf = strata_tf .untransform_posterior (posterior , X = X )
436
432
self .assertEqual (p_utf .device .type , self .device .type )
437
- self .assertTrue (p_utf .dtype == dtype )
433
+ self .assertEqual (p_utf .dtype , dtype )
438
434
strata_means , strata_stdvs , _ = strata_tf ._get_per_input_means_stdvs (
439
435
X = X , include_stdvs_sq = False
440
436
)
441
437
mean_expected = strata_means + strata_stdvs * posterior .mean
442
- variance_expected = strata_stdvs ** 2 * posterior .variance
438
+ expected_raw_variance = ( strata_stdvs ** 2 * posterior .variance ). squeeze ()
443
439
self .assertAllClose (p_utf .mean , mean_expected )
444
- self .assertAllClose (p_utf .variance , variance_expected )
440
+ # The variance will be clamped to a minimum (typically 1e-6), so
441
+ # check both the raw values and clamped values
442
+ raw_variance = p_utf .mvn .lazy_covariance_matrix .diagonal (
443
+ dim1 = - 1 , dim2 = 2
444
+ )
445
+ self .assertAllClose (raw_variance , expected_raw_variance )
446
+ expected_clamped_variance = expected_raw_variance .clamp (
447
+ min = min_variance .value (dtype = raw_variance .dtype )
448
+ ).unsqueeze (- 1 )
449
+ self .assertAllClose (p_utf .variance , expected_clamped_variance )
445
450
samples = p_utf .rsample ()
446
451
self .assertEqual (samples .shape , torch .Size ([1 ]) + shape )
447
452
samples = p_utf .rsample (sample_shape = torch .Size ([4 ]))
0 commit comments