From 4d024fb8fb314562378be5075ea9d311f9d35821 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Sat, 13 Sep 2025 17:39:16 +0100
Subject: [PATCH 01/13] Add `infeasible_obj` argument to qLogNEI

---
 botorch/acquisition/logei.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/botorch/acquisition/logei.py b/botorch/acquisition/logei.py
index 62821e461e..7d9742ef66 100644
--- a/botorch/acquisition/logei.py
+++ b/botorch/acquisition/logei.py
@@ -282,6 +282,7 @@ def __init__(
         tau_relu: float = TAU_RELU,
         marginalize_dim: int | None = None,
         incremental: bool = True,
+        infeasible_obj: Tensor | float | None = None,
     ) -> None:
         r"""q-Noisy Expected Improvement.
 
@@ -324,6 +325,9 @@ def __init__(
             incremental: Whether to compute incremental EI over the pending points
                 or compute EI of the joint batch improvement (including
                 pending points).
+            infeasible_obj: A Tensor or float used to calculate the best objective
+                when no feasible points exist. If None, a lower bound on the
+                objective values is estimated from the GP posterior.
 
         TODO: similar to qNEHVI, when we are using sequential greedy candidate
         selection, we could incorporate pending points X_baseline and compute
@@ -333,6 +337,7 @@ def __init__(
         # TODO: separate out baseline variables initialization and other functions
         # in qNEI to avoid duplication of both code and work at runtime.
         self.incremental = incremental
+        self.infeasible_obj = infeasible_obj
 
         super().__init__(
             model=model,
@@ -570,6 +575,7 @@ def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tens
             objective=self.objective,
             posterior_transform=self.posterior_transform,
             X_baseline=self.X_baseline,
+            infeasible_obj=self.infeasible_obj,
         )

From 40623f5f9450f9c11bc7e33bc21cf5d972cadb60 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Sat, 13 Sep 2025 18:08:02 +0100
Subject: [PATCH 02/13] Replace convex weighted samples with uniform samples in
 bounding box

---
 botorch/acquisition/utils.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 00fbe35291..507ffc3081 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -171,8 +171,9 @@ def _estimate_objective_lower_bound(
     posterior_transform: PosteriorTransform | None,
     X: Tensor,
 ) -> Tensor:
-    """Estimates a lower bound on the objective values by evaluating the model at convex
-    combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
+    """Estimates a lower bound on the objective values by evaluating the model at
+    uniformly random points in the bounding box of `X`, returning the 6-sigma lower
+    bound of the computed statistics.
 
     Args:
         model: A fitted model.
@@ -183,19 +184,20 @@ def _estimate_objective_lower_bound(
     Returns:
        An `m`-dimensional Tensor of lower bounds of the objectives.
""" - convex_weights = torch.rand( - 32, - X.shape[-2], - dtype=X.dtype, - device=X.device, - ) - weights_sum = convex_weights.sum(dim=0, keepdim=True) - convex_weights = convex_weights / weights_sum + # sample + # we do not have access to `bounds` here, so we infer the bounding box + # from data, expanding by 10% in each direction + X_lb = X.min(dim=-2) + X_ub = X.max(dim=-2) + X_range = X_ub - X_lb + X_padding = 0.1 * X_range + uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device) + X_samples = X_lb - X_padding + uniform_samples * (X_range + X_padding) # infeasible cost M is such that -M < min_x f(x), thus # 0 < min_x f(x) - (-M), so we should take -M as a lower # bound on the best feasible objective return -get_infeasible_cost( - X=convex_weights @ X, + X=X_samples, model=model, objective=objective, posterior_transform=posterior_transform, From 670fbc666cc678e6f4eeee18d207c54e2ed4d490 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Sat, 13 Sep 2025 19:28:27 +0100 Subject: [PATCH 03/13] Check upper and lower bound of posterior for objective evaluation --- botorch/acquisition/utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 507ffc3081..16d4bbcf39 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -184,7 +184,6 @@ def _estimate_objective_lower_bound( Returns: A `m`-dimensional Tensor of lower bounds of the objectives. """ - # sample # we do not have access to `bounds` here, so we infer the bounding box # from data, expanding by 10% in each direction X_lb = X.min(dim=-2) @@ -237,8 +236,18 @@ def objective(Y: Tensor, X: Tensor | None = None): return Y.squeeze(-1) posterior = model.posterior(X, posterior_transform=posterior_transform) - lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X) - if lb.ndim < posterior.mean.ndim: + # We check both the upper and lower bound of the posterior, since the objective + # may be increasing or decreasing. For objectives that are neither (eg. absolute + # distance from a target), this should still provide a good bound. + lb = torch.stack( + [ + objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X), + objective(posterior.mean + 6 * posterior.variance.clamp_min(0).sqrt(), X=X), + ], + dim=0, + ) + + if lb.ndim - 1 < posterior.mean.ndim: lb = lb.unsqueeze(-1) # Take outcome-wise min. Looping in to handle batched models. 
    while lb.dim() > 1:

From e7e9520a10cfe2e2a86a2664e910e6910c70a644 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Sat, 13 Sep 2025 19:56:41 +0100
Subject: [PATCH 04/13] Improve `prune_inferior_points` when all points are
 infeasible

---
 botorch/acquisition/utils.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 16d4bbcf39..27098c4b80 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -29,7 +29,10 @@
 from botorch.sampling.base import MCSampler
 from botorch.sampling.get_sampler import get_sampler
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.objective import compute_feasibility_indicator
+from botorch.utils.objective import (
+    compute_feasibility_indicator,
+    compute_smoothed_feasibility_indicator,
+)
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_ensemble, normalize_indices
 from gpytorch.models import GP
@@ -385,7 +388,24 @@ def prune_inferior_points(
         sampler=sampler,
         marginalize_dim=marginalize_dim,
     )
-    if infeas.any():
+    if infeas.all():
+        # if no points are feasible, keep the point closest to being feasible
+        with torch.no_grad():
+            posterior = model.posterior(X=X, posterior_transform=posterior_transform)
+        if sampler is None:
+            sampler = get_sampler(
+                posterior=posterior, sample_shape=torch.Size([num_samples])
+            )
+        samples = sampler(posterior)
+        # use the probability of feasibility as the objective for computing best points
+        obj_vals = compute_smoothed_feasibility_indicator(
+            constraints=constraints,
+            samples=samples,
+            eta=1e-3,
+            log=True,
+        )
+
+    elif infeas.any():
         # set infeasible points to worse than worst objective across all samples
         # Use clone() here to avoid deprecated `index_put_` on an expanded tensor
         obj_vals = obj_vals.clone()

From b145f067c7cb2ac7fd77b4959bebd546d573181a Mon Sep 17 00:00:00 2001
From: Toby Boyne <48383196+TobyBoyne@users.noreply.github.com>
Date: Thu, 18 Sep 2025 02:14:58 -0700
Subject: [PATCH 05/13] Correct padding size

Co-authored-by: Max Balandat
---
 botorch/acquisition/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 27098c4b80..f7143c4af9 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -194,7 +194,7 @@ def _estimate_objective_lower_bound(
     X_range = X_ub - X_lb
     X_padding = 0.1 * X_range
     uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device)
-    X_samples = X_lb - X_padding + uniform_samples * (X_range + X_padding)
+    X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
     # infeasible cost M is such that -M < min_x f(x), thus
     # 0 < min_x f(x) - (-M), so we should take -M as a lower
     # bound on the best feasible objective

From 1e0625acf2a6ad6175b0a8c1f5ab681543df0bd5 Mon Sep 17 00:00:00 2001
From: Toby Boyne <48383196+TobyBoyne@users.noreply.github.com>
Date: Thu, 18 Sep 2025 02:15:23 -0700
Subject: [PATCH 06/13] Extract std calculation from lb

Co-authored-by: Max Balandat
---
 botorch/acquisition/utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index f7143c4af9..537ba33dd5 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -242,10 +242,11 @@ def objective(Y: Tensor, X: Tensor | None = None):
     # We check both the upper and lower bound of the posterior, since the objective
     # may be increasing or decreasing. For objectives that are neither (e.g., absolute
     # distance from a target), this should still provide a good bound.
+    six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
     lb = torch.stack(
         [
-            objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X),
-            objective(posterior.mean + 6 * posterior.variance.clamp_min(0).sqrt(), X=X),
+            objective(posterior.mean - six_stdv, X=X),
+            objective(posterior.mean + six_stdv, X=X),
         ],
         dim=0,
     )

From 7a04f732f1eb6b9b562eb78262ec50651194c418 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 14:15:33 +0100
Subject: [PATCH 07/13] Change min to explicitly minimize the upper/lower bound
 first

---
 botorch/acquisition/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 537ba33dd5..ffe75d112c 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -250,8 +250,9 @@ def objective(Y: Tensor, X: Tensor | None = None):
         ],
         dim=0,
     )
+    lb = lb.min(dim=0).values
 
-    if lb.ndim - 1 < posterior.mean.ndim:
+    if lb.ndim < posterior.mean.ndim:
         lb = lb.unsqueeze(-1)
     # Take outcome-wise min. Looping in to handle batched models.
     while lb.dim() > 1:

From cfffbe2713aba6f6169235d61744b119af5d3934 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 14:38:19 +0100
Subject: [PATCH 08/13] Add warning message to `_prune_inferior_shared_processing`
 when all points are infeasible

---
 botorch/acquisition/utils.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index ffe75d112c..e7bbee9255 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -11,6 +11,7 @@
 from __future__ import annotations
 
 import math
+import warnings
 from collections.abc import Callable
 
 import torch
@@ -24,6 +25,7 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.fully_bayesian import MCMC_DIM
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
@@ -327,6 +329,15 @@ def _prune_inferior_shared_processing(
         samples=samples,
         marginalize_dim=marginalize_dim,
     )
+
+    if infeas.all():
+        warnings.warn(
+            "When all training points are infeasible, it is better to use "
+            "q(Log)ProbabilityOfFeasibility.",
+            BotorchWarning,
+            stacklevel=2,
+        )
+
     return max_points, obj_vals, infeas

From f768aab7a17164ff28073b08431ce89e35e1f6cb Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 15:56:45 +0100
Subject: [PATCH 09/13] Add test for no feasible points recommending
 ProbabilityOfFeasibility

---
 test/acquisition/test_logei.py       | 17 +++++++++++++++++
 test/acquisition/test_monte_carlo.py | 17 +++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py
index 09bd67d253..aad82ca17d 100644
--- a/test/acquisition/test_logei.py
+++ b/test/acquisition/test_logei.py
@@ -589,6 +589,23 @@ def test_prune_baseline(self):
         self.assertTrue(torch.equal(acqf.X_baseline, X_baseline[[-1]]))
         self.assertEqual(kwargs["marginalize_dim"], -3)
 
+        # test warning if all samples are infeasible
+        samples3 = samples[:-1]
+        mm = MockModel(
+            MockPosterior(
+                samples=samples3,
+            )
+        )
+        with self.assertWarnsRegex(BotorchWarning, "ProbabilityOfFeasibility"):
+            acqf = qLogNoisyExpectedImprovement(
+                model=mm,
+                X_baseline=X_baseline[:-1],
+                prune_baseline=True,
+                cache_root=False,
+                objective=mc_obj,
+                constraints=constraints,
+            )
+
     def test_cache_root(self):
         sample_cached_path = (
             "botorch.acquisition.cached_cholesky.sample_cached_cholesky"
         )
diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py
index 6e9db86e91..bb4beed591 100644
--- a/test/acquisition/test_monte_carlo.py
+++ b/test/acquisition/test_monte_carlo.py
@@ -469,6 +469,23 @@ def _test_prune_baseline(self, dtype: torch.dtype) -> None:
         self.assertTrue(torch.equal(acqf.X_baseline, X_baseline[[-1]]))
         self.assertEqual(kwargs["marginalize_dim"], -3)
 
+        # test warning if all samples are infeasible
+        samples3 = samples[:-1]
+        mm = MockModel(
+            MockPosterior(
+                samples=samples3,
+            )
+        )
+        with self.assertWarnsRegex(BotorchWarning, "ProbabilityOfFeasibility"):
+            acqf = qNoisyExpectedImprovement(
+                model=mm,
+                X_baseline=X_baseline[:-1],
+                prune_baseline=True,
+                cache_root=False,
+                objective=objective,
+                constraints=constraints,
+            )
+
     def test_cache_root(self):
         with catch_warnings():
             simplefilter("ignore", category=NumericsWarning)

From e3012c613fe289ceaa64f08d22fb6de648611433 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 16:00:30 +0100
Subject: [PATCH 10/13] Remove `infeasible_obj` argument

---
 botorch/acquisition/logei.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/botorch/acquisition/logei.py b/botorch/acquisition/logei.py
index afea4ce240..c0b0fe21d4 100644
--- a/botorch/acquisition/logei.py
+++ b/botorch/acquisition/logei.py
@@ -282,7 +282,6 @@ def __init__(
         tau_relu: float = TAU_RELU,
         marginalize_dim: int | None = None,
         incremental: bool = True,
-        infeasible_obj: Tensor | float | None = None,
     ) -> None:
         r"""q-Noisy Expected Improvement.
 
@@ -325,9 +324,6 @@ def __init__(
             incremental: Whether to compute incremental EI over the pending points
                 or compute EI of the joint batch improvement (including
                 pending points).
-            infeasible_obj: A Tensor or float used to calculate the best objective
-                when no feasible points exist. If None, a lower bound on the
-                objective values is estimated from the GP posterior.
 
         TODO: similar to qNEHVI, when we are using sequential greedy candidate
         selection, we could incorporate pending points X_baseline and compute
@@ -337,7 +333,6 @@ def __init__(
         # TODO: separate out baseline variables initialization and other functions
         # in qNEI to avoid duplication of both code and work at runtime.
         self.incremental = incremental
-        self.infeasible_obj = infeasible_obj
 
         super().__init__(
             model=model,
@@ -575,7 +570,6 @@ def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tens
             objective=self.objective,
             posterior_transform=self.posterior_transform,
             X_baseline=self.X_baseline,
-            infeasible_obj=self.infeasible_obj,
         )

From 1440623c9099ab80e3782dcebfda70f036287c47 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 16:15:37 +0100
Subject: [PATCH 11/13] Fix `min` and `max` functions not using `.values`

---
 botorch/acquisition/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index e7bbee9255..713172866f 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -191,8 +191,8 @@ def _estimate_objective_lower_bound(
     """
     # we do not have access to `bounds` here, so we infer the bounding box
     # from data, expanding by 10% in each direction
-    X_lb = X.min(dim=-2)
-    X_ub = X.max(dim=-2)
+    X_lb = X.min(dim=-2).values
+    X_ub = X.max(dim=-2).values
     X_range = X_ub - X_lb
     X_padding = 0.1 * X_range

From db5303ff47e923de820f19291ea79297a274be06 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Tue, 7 Oct 2025 16:59:09 +0100
Subject: [PATCH 12/13] Move warning to `compute_best_feasible_objective`

---
 botorch/acquisition/utils.py         | 15 +++++++--------
 test/acquisition/test_logei.py       | 17 -----------------
 test/acquisition/test_monte_carlo.py | 17 -----------------
 test/acquisition/test_utils.py       | 20 ++++++++++++--------
 4 files changed, 19 insertions(+), 50 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 713172866f..ed8143d851 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -155,6 +155,13 @@ def compute_best_feasible_objective(
             raise ValueError(
                 "Must specify `X_baseline` when no feasible observation exists."
            )
+        warnings.warn(
+            "When all training points are infeasible, it is better to use "
+            "q(Log)ProbabilityOfFeasibility.",
+            BotorchWarning,
+            stacklevel=2,
+        )
+
         infeasible_value = _estimate_objective_lower_bound(
             model=model,
             objective=objective,
@@ -330,14 +337,6 @@ def _prune_inferior_shared_processing(
         samples=samples,
         marginalize_dim=marginalize_dim,
     )
-
-    if infeas.all():
-        warnings.warn(
-            "When all training points are infeasible, it is better to use "
-            "q(Log)ProbabilityOfFeasibility.",
-            BotorchWarning,
-            stacklevel=2,
-        )
 
     return max_points, obj_vals, infeas

diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py
index aad82ca17d..09bd67d253 100644
--- a/test/acquisition/test_logei.py
+++ b/test/acquisition/test_logei.py
@@ -589,23 +589,6 @@ def test_prune_baseline(self):
         self.assertTrue(torch.equal(acqf.X_baseline, X_baseline[[-1]]))
         self.assertEqual(kwargs["marginalize_dim"], -3)
 
-        # test warning if all samples are infeasible
-        samples3 = samples[:-1]
-        mm = MockModel(
-            MockPosterior(
-                samples=samples3,
-            )
-        )
-        with self.assertWarnsRegex(BotorchWarning, "ProbabilityOfFeasibility"):
-            acqf = qLogNoisyExpectedImprovement(
-                model=mm,
-                X_baseline=X_baseline[:-1],
-                prune_baseline=True,
-                cache_root=False,
-                objective=mc_obj,
-                constraints=constraints,
-            )
-
     def test_cache_root(self):
         sample_cached_path = (
             "botorch.acquisition.cached_cholesky.sample_cached_cholesky"
         )
diff --git a/test/acquisition/test_monte_carlo.py b/test/acquisition/test_monte_carlo.py
index bb4beed591..6e9db86e91 100644
--- a/test/acquisition/test_monte_carlo.py
+++ b/test/acquisition/test_monte_carlo.py
@@ -469,23 +469,6 @@ def _test_prune_baseline(self, dtype: torch.dtype) -> None:
         self.assertTrue(torch.equal(acqf.X_baseline, X_baseline[[-1]]))
         self.assertEqual(kwargs["marginalize_dim"], -3)
 
-        # test warning if all samples are infeasible
-        samples3 = samples[:-1]
-        mm = MockModel(
-            MockPosterior(
-                samples=samples3,
-            )
-        )
-        with self.assertWarnsRegex(BotorchWarning, "ProbabilityOfFeasibility"):
-            acqf = qNoisyExpectedImprovement(
-                model=mm,
-                X_baseline=X_baseline[:-1],
-                prune_baseline=True,
-                cache_root=False,
-                objective=objective,
-                constraints=constraints,
-            )
-
     def test_cache_root(self):
         with catch_warnings():
             simplefilter("ignore", category=NumericsWarning)
diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py
index b8115ba0af..f071545de4 100644
--- a/test/acquisition/test_utils.py
+++ b/test/acquisition/test_utils.py
@@ -32,6 +32,7 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import SingleTaskGP
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal
@@ -154,14 +155,17 @@ def test_compute_best_feasible_objective(self):
             def objective(Y, X):
                 return Y.squeeze(-1) - 5.0
 
-            best_f = compute_best_feasible_objective(
-                samples=samples,
-                obj=obj,
-                constraints=[lambda X: torch.ones_like(X[..., 0])],
-                model=mm,
-                X_baseline=X,
-                objective=objective,
-            )
+            with self.assertWarnsRegex(
+                BotorchWarning, "ProbabilityOfFeasibility"
+            ):
+                best_f = compute_best_feasible_objective(
+                    samples=samples,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    model=mm,
+                    X_baseline=X,
+                    objective=objective,
+                )
             expected_best_f = torch.full(
                 sample_shape + batch_shape,
                 -get_infeasible_cost(X=X, model=mm, objective=objective).item(),

From bc23e8fba8de4ae1151a5e85a236fb3675a643e8 Mon Sep 17 00:00:00 2001
From: Toby Boyne
Date: Wed, 8 Oct 2025 17:40:41 +0100
Subject: [PATCH 13/13] Fix handling of dimensions for batched inputs

---
 botorch/acquisition/utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index ed8143d851..e11e8c7631 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -198,11 +198,13 @@ def _estimate_objective_lower_bound(
     """
     # we do not have access to `bounds` here, so we infer the bounding box
    # from data, expanding by 10% in each direction
-    X_lb = X.min(dim=-2).values
-    X_ub = X.max(dim=-2).values
+    X_lb = X.min(dim=-2, keepdim=True).values
+    X_ub = X.max(dim=-2, keepdim=True).values
     X_range = X_ub - X_lb
     X_padding = 0.1 * X_range
-    uniform_samples = torch.rand(32, X.shape[-1], dtype=X.dtype, device=X.device)
+    uniform_samples = torch.rand(
+        *X.shape[:-2], 32, X.shape[-1], dtype=X.dtype, device=X.device
+    )
     X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
     # infeasible cost M is such that -M < min_x f(x), thus
     # 0 < min_x f(x) - (-M), so we should take -M as a lower
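
The sketches below are appended illustrations, not part of the patch series. This first one shows the sampling scheme that patches 02, 05, and 13 converge on: uniform samples in a padded bounding box of the baseline points, with batch dimensions preserved. The helper name `sample_bounding_box` and the example shapes are hypothetical.

```python
import torch


def sample_bounding_box(X: torch.Tensor, n: int = 32) -> torch.Tensor:
    """Draw `n` uniform samples in the bounding box of `X`, padded by 10% of
    the per-dimension range on each side (hypothetical helper, mirroring the
    cumulative effect of patches 02, 05, and 13)."""
    # `.values` is needed because `min`/`max` over a dim return a namedtuple
    # (patch 11); `keepdim=True` preserves batch shapes (patch 13)
    X_lb = X.min(dim=-2, keepdim=True).values
    X_ub = X.max(dim=-2, keepdim=True).values
    X_range = X_ub - X_lb
    X_padding = 0.1 * X_range
    # batched uniform noise: one set of `n` points per batch of `X`
    u = torch.rand(*X.shape[:-2], n, X.shape[-1], dtype=X.dtype, device=X.device)
    # the padded box [lb - pad, ub + pad] has width X_range + 2 * X_padding,
    # the off-by-one-padding bug fixed in patch 05
    return X_lb - X_padding + u * (X_range + 2 * X_padding)


# e.g. a (4, 10, 3)-shaped batch of baseline points yields (4, 32, 3) samples
X = torch.rand(4, 10, 3)
assert sample_bounding_box(X).shape == (4, 32, 3)
```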
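
Patches 03, 06, and 07 together replace the one-sided 6-sigma bound with a two-sided one. Below is a self-contained sketch of the resulting logic, using made-up posterior statistics and a toy non-monotone objective (the absolute-distance-from-a-target case mentioned in the patch 03 comment):

```python
import torch

# stand-ins for `posterior.mean` and `posterior.variance` over 32 points, m=1
mean = torch.randn(32, 1)
variance = torch.rand(32, 1)


def objective(Y: torch.Tensor, X=None) -> torch.Tensor:
    # toy non-monotone objective: negative absolute distance from a 0.5 target
    return -(Y - 0.5).abs().squeeze(-1)


# evaluate the objective at both 6-sigma envelopes of the posterior (patch 03),
# factoring out the shared std term (patch 06), then explicitly minimize over
# the two envelopes before any further reductions (patch 07)
six_stdv = 6 * variance.clamp_min(0).sqrt()
lb = torch.stack(
    [objective(mean - six_stdv), objective(mean + six_stdv)],
    dim=0,
)
lb = lb.min(dim=0).values  # shape: (32,)
```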
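
Patch 04 falls back to ranking all-infeasible baseline points by their smoothed log probability of feasibility. The sketch below uses a simplified stand-in for BoTorch's `compute_smoothed_feasibility_indicator`, not the library implementation; the constraint, shapes, and `log_feasibility` helper are made up for illustration.

```python
import torch

# Simplified stand-in (an assumption, not the BoTorch implementation):
# constraints follow the BoTorch convention that `c(samples) <= 0` means
# feasible; each constraint is smoothed with a log-sigmoid of sharpness
# `1 / eta` and the per-constraint log-indicators are summed.
def log_feasibility(samples, constraints, eta: float = 1e-3):
    return sum(
        torch.nn.functional.logsigmoid(-c(samples) / eta) for c in constraints
    )


# when every baseline point is infeasible (patch 04), ranking by this score
# keeps the points closest to the feasible region instead of pruning blindly
samples = torch.randn(128, 6, 1)  # (mc_samples, n_points, m) posterior draws
constraints = [lambda Y: Y[..., 0] - 2.0]  # feasible iff Y[..., 0] <= 2
scores = log_feasibility(samples, constraints)  # (mc_samples, n_points)
keep = scores.mean(dim=0).topk(3).indices  # the 3 "least infeasible" points
```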