
Commit d147bd5

esantorella authored and facebook-github-bot committed
Stop stratified standardize transform test from failing flakily (#2689)
Summary:
Pull Request resolved: #2689

Context: This test has been failing flakily for two reasons:
* Variance values get clamped when they are very small, causing a check to fail
* Numerical tolerances were a little too tight in single precision

Changes:
* Test for both clamped and unclamped values
* Loosened a tolerance
* Made the seed randomly chosen between 0 and 100 (inclusive) and checked that the test passes for all such values
* Updated some checks to use `assertAllClose` for more informative errors

Reviewed By: sdaulton

Differential Revision: D68444531

fbshipit-source-id: 8f8ca4971b8d5d6b517c858282d3d703d08f8d83
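Why the clamping made the old check flaky: GPyTorch floors posterior variances at a dtype-dependent minimum, so comparing the reported variance against an unclamped expected value fails whenever any entry lands below that floor. A minimal sketch of the effect, not taken from the commit (the tensor values are made up):

import torch
from gpytorch.settings import min_variance

dtype = torch.float32
raw = torch.tensor([1e-9, 2e-7, 0.5], dtype=dtype)  # made-up raw posterior variances
floor = min_variance.value(dtype=dtype)  # defaults to 1e-6 in single precision
clamped = raw.clamp(min=floor)  # what the posterior reports as its variance
print(torch.allclose(clamped, raw))  # False: the first two entries were floored

This is why the updated test below compares the raw covariance diagonal and the clamped variance property against separate expected values.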
1 parent 7ca0b2e · commit d147bd5


test/models/transforms/test_outcome.py

Lines changed: 34 additions & 29 deletions
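One of the listed changes swaps self.assertTrue(torch.allclose(...)) for assertAllClose. The payoff is the failure message: assertTrue can only report "False is not true", while an assert_close-style check names the offending elements. A sketch of the difference using torch.testing.assert_close directly (assuming, without verifying here, that BotorchTestCase.assertAllClose wraps it):

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.1])
try:
    torch.testing.assert_close(a, b)
except AssertionError as e:
    # reports the mismatched elements plus the greatest
    # absolute and relative differences
    print(e)
# By contrast, assertTrue(torch.allclose(a, b)) would fail with only
# "False is not true".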
@@ -6,6 +6,7 @@
 
 import itertools
 from copy import deepcopy
+from random import randint
 
 import torch
 from botorch.models.transforms.outcome import (
@@ -24,6 +25,7 @@
 from botorch.posteriors import GPyTorchPosterior, TransformedPosterior
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
+from gpytorch.settings import min_variance
 from linear_operator.operators import (
     BlockDiagLinearOperator,
     DenseLinearOperator,
@@ -368,9 +370,12 @@ def test_standardize_state_dict(self):
 
     def test_stratified_standardize(self):
         n = 5
+        seed = randint(0, 100)
+        torch.manual_seed(seed)
         for dtype, batch_shape in itertools.product(
             (torch.float, torch.double), (torch.Size([]), torch.Size([3]))
         ):
+            torch.manual_seed(seed)
             X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
             X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
             Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
@@ -389,38 +394,29 @@ def test_stratified_standardize(self):
             Y1 = Y[mask1].view(*batch_shape, -1, 1)
             Yvar1 = Yvar[mask1].view(*batch_shape, -1, 1)
             X1 = X[mask1].view(*batch_shape, -1, 1)
-            tf0 = Standardize(
-                m=1,
-                batch_shape=batch_shape,
-            )
+            tf0 = Standardize(m=1, batch_shape=batch_shape)
             tf_Y0, tf_Yvar0 = tf0(Y=Y0, Yvar=Yvar0, X=X0)
-            tf1 = Standardize(
-                m=1,
-                batch_shape=batch_shape,
-            )
+            tf1 = Standardize(m=1, batch_shape=batch_shape)
             tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
             # check that stratified means are expected
-            self.assertTrue(torch.allclose(strata_tf.means[..., :1, :], tf0.means))
-            self.assertTrue(torch.allclose(strata_tf.means[..., 1:, :], tf1.means))
-            self.assertTrue(torch.allclose(strata_tf.stdvs[..., :1, :], tf0.stdvs))
-            self.assertTrue(torch.allclose(strata_tf.stdvs[..., 1:, :], tf1.stdvs))
+            self.assertAllClose(strata_tf.means[..., :1, :], tf0.means)
+            self.assertAllClose(strata_tf.means[..., 1:, :], tf1.means)
+            self.assertAllClose(strata_tf.stdvs[..., :1, :], tf0.stdvs)
+            self.assertAllClose(strata_tf.stdvs[..., 1:, :], tf1.stdvs)
             # check the transformed values
-            self.assertTrue(
-                torch.allclose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
-            )
-            self.assertTrue(
-                torch.allclose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
-            )
-            self.assertTrue(
-                torch.allclose(tf_Yvar0, tf_Yvar[mask0].view(*batch_shape, -1, 1))
-            )
-            self.assertTrue(
-                torch.allclose(tf_Yvar1, tf_Yvar[mask1].view(*batch_shape, -1, 1))
-            )
+            self.assertAllClose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
+            self.assertAllClose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
+            self.assertAllClose(tf_Yvar0, tf_Yvar[mask0].view(*batch_shape, -1, 1))
+            self.assertAllClose(tf_Yvar1, tf_Yvar[mask1].view(*batch_shape, -1, 1))
            untf_Y, untf_Yvar = strata_tf.untransform(Y=tf_Y, Yvar=tf_Yvar, X=X)
             # test untransform
-            self.assertTrue(torch.allclose(Y, untf_Y))
-            self.assertTrue(torch.allclose(Yvar, untf_Yvar))
+            if dtype == torch.float32:
+                # defaults are 1e-5, 1e-8
+                tols = {"rtol": 2e-5, "atol": 8e-8}
+            else:
+                tols = {}
+            self.assertAllClose(Y, untf_Y, **tols)
+            self.assertAllClose(Yvar, untf_Yvar)
 
             # test untransform_posterior
             for lazy in (True, False):
@@ -434,14 +430,23 @@ def test_stratified_standardize(self):
                 )
                 p_utf = strata_tf.untransform_posterior(posterior, X=X)
                 self.assertEqual(p_utf.device.type, self.device.type)
-                self.assertTrue(p_utf.dtype == dtype)
+                self.assertEqual(p_utf.dtype, dtype)
                 strata_means, strata_stdvs, _ = strata_tf._get_per_input_means_stdvs(
                     X=X, include_stdvs_sq=False
                 )
                 mean_expected = strata_means + strata_stdvs * posterior.mean
-                variance_expected = strata_stdvs**2 * posterior.variance
+                expected_raw_variance = (strata_stdvs**2 * posterior.variance).squeeze()
                 self.assertAllClose(p_utf.mean, mean_expected)
-                self.assertAllClose(p_utf.variance, variance_expected)
+                # The variance will be clamped to a minimum (typically 1e-6), so
+                # check both the raw values and the clamped values
+                raw_variance = p_utf.mvn.lazy_covariance_matrix.diagonal(
+                    dim1=-1, dim2=-2
+                )
+                self.assertAllClose(raw_variance, expected_raw_variance)
+                expected_clamped_variance = expected_raw_variance.clamp(
+                    min=min_variance.value(dtype=raw_variance.dtype)
+                ).unsqueeze(-1)
+                self.assertAllClose(p_utf.variance, expected_clamped_variance)
                 samples = p_utf.rsample()
                 self.assertEqual(samples.shape, torch.Size([1]) + shape)
                 samples = p_utf.rsample(sample_shape=torch.Size([4]))
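The seeding pattern above generalizes to other flaky tests: draw one seed per run, then re-seed before each parameterized case so every (dtype, batch_shape) combination starts from the same generator state and a failure can be replayed from the seed. A standalone sketch of the pattern; the shapes and the printout are illustrative, not part of the commit:

from random import randint

import torch

seed = randint(0, 100)  # inclusive on both ends, as in the test
print(f"replay failures with torch.manual_seed({seed})")  # illustrative logging
for dtype in (torch.float32, torch.float64):
    torch.manual_seed(seed)  # each case starts from the same generator state
    X = torch.rand(5, 2, dtype=dtype)
    Y = torch.randn(5, 1, dtype=dtype)

Note also that the float32 branch loosens the comparison tolerances only to rtol=2e-5 and atol=8e-8, a small factor over the torch.allclose defaults of rtol=1e-5 and atol=1e-8, so the round-trip check stays tight in double precision.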
