Skip to content

Commit 6b75672

Browse files
esantorellafacebook-github-bot
authored andcommitted
Fix bug in optimize_objective with fixed features (#2691)
Summary: Pull Request resolved: #2691 **Context:** As per #2686, bounds for `optimize_acqf` are not constructed correctly in `optimize_objective`, which is used in input constructors for qKG-type acquisition functions. This issue wasn't surfaced by unit tests because `optimize_acqf` was mocked out. In the process of shoring up the test, I discovered a second bug: This `optimize_objective` doesn't work with constraints, because the optimizer is set to be L-BFGS-B when it isn't otherwise specified, and L-BFGS-B doesn't work with BoTorch-style constraints (only simple box constraints, aka BoTorch bounds). So I guess the input constructors for qKG-style acquisition functions haven't been working with fixed features or with constraints for a long time -- both usages would just error. The existing unit test should have caught this but didn't due to use of mocks, so I removed the mocking. **Changes:** In `optimize_objective`: * Use `bounds.shape` instead of `len(bounds)` when constructing a list of features for `fixed_features_list` * Don't specify 'method' if the user doesn't pass it, so it can be automatically chosen based on the presence of constraints. Other: * In `optimize_acqf`, cleaned up some logic. This doesn't have any effect on behavior. * Added a type annotation Reviewed By: saitcakmak Differential Revision: D68464825 fbshipit-source-id: 08eb4e0f6f91c96e572650b2c00a5515e9497e36
1 parent 589260b commit 6b75672

File tree

4 files changed

+101
-47
lines changed

4 files changed

+101
-47
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,7 @@ def optimize_objective(
17581758
columns=list(fixed_features.keys()),
17591759
values=list(fixed_features.values()),
17601760
)
1761-
free_feature_dims = list(range(len(bounds)) - fixed_features.keys())
1761+
free_feature_dims = list(range(bounds.shape[1]) - fixed_features.keys())
17621762
free_feature_bounds = bounds[:, free_feature_dims] # (2, d' <= d)
17631763
else:
17641764
free_feature_bounds = bounds
@@ -1775,18 +1775,21 @@ def optimize_objective(
17751775
rhs = -b[i, 0]
17761776
inequality_constraints.append((indices, coefficients, rhs))
17771777

1778+
options = {
1779+
"batch_limit": optimizer_options.get("batch_limit", 8),
1780+
"maxiter": optimizer_options.get("maxiter", 200),
1781+
"nonnegative": optimizer_options.get("nonnegative", False),
1782+
}
1783+
if "method" in optimizer_options:
1784+
options["method"] = optimizer_options.pop("method")
1785+
17781786
return optimize_acqf(
17791787
acq_function=acq_function,
17801788
bounds=free_feature_bounds,
17811789
q=q,
17821790
num_restarts=optimizer_options.get("num_restarts", 60),
17831791
raw_samples=optimizer_options.get("raw_samples", 1024),
1784-
options={
1785-
"batch_limit": optimizer_options.get("batch_limit", 8),
1786-
"maxiter": optimizer_options.get("maxiter", 200),
1787-
"nonnegative": optimizer_options.get("nonnegative", False),
1788-
"method": optimizer_options.get("method", "L-BFGS-B"),
1789-
},
1792+
options=options,
17901793
inequality_constraints=inequality_constraints,
17911794
fixed_features=None, # handled inside the acquisition function
17921795
post_processing_func=post_processing_func,

botorch/optim/optimize.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
362362
)
363363

364364
bounds = opt_inputs.bounds
365-
gen_kwargs: dict[str, Any] = {
366-
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
367-
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
368-
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
369-
"fixed_features": opt_inputs.fixed_features,
370-
"timeout_sec": timeout_sec,
371-
}
365+
lower_bounds = None if bounds[0].isinf().all() else bounds[0]
366+
upper_bounds = None if bounds[1].isinf().all() else bounds[1]
367+
gen_options = {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}
372368

369+
gen_kwargs = {}
373370
for constraint_name in [
374371
"inequality_constraints",
375372
"equality_constraints",
@@ -386,7 +383,14 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
386383
batch_candidates_curr,
387384
batch_acq_values_curr,
388385
) = opt_inputs.gen_candidates(
389-
batched_ics_, opt_inputs.acq_function, **gen_kwargs
386+
batched_ics_,
387+
opt_inputs.acq_function,
388+
lower_bounds=lower_bounds,
389+
upper_bounds=upper_bounds,
390+
options=gen_options,
391+
fixed_features=opt_inputs.fixed_features,
392+
timeout_sec=timeout_sec,
393+
**gen_kwargs,
390394
)
391395
opt_warnings += ws
392396
batch_candidates_list.append(batch_candidates_curr)
@@ -624,7 +628,7 @@ def optimize_acqf(
624628
retry_on_optimization_warning=retry_on_optimization_warning,
625629
ic_gen_kwargs=ic_gen_kwargs,
626630
)
627-
return _optimize_acqf(opt_acqf_inputs)
631+
return _optimize_acqf(opt_inputs=opt_acqf_inputs)
628632

629633

630634
def _optimize_acqf(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:

botorch/posteriors/gpytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, distribution: MultivariateNormal) -> None:
4545
MultitaskMultivariateNormal (multi-output case).
4646
"""
4747
super().__init__(distribution=distribution)
48-
self._is_mt = isinstance(distribution, MultitaskMultivariateNormal)
48+
self._is_mt: bool = isinstance(distribution, MultitaskMultivariateNormal)
4949

5050
@property
5151
def mvn(self) -> MultivariateNormal:
@@ -224,7 +224,7 @@ def scalarize_posterior_gpytorch(
224224
"""
225225
mean = posterior.mean
226226
q, m = mean.shape[-2:]
227-
_validate_scalarize_inputs(weights, m)
227+
_validate_scalarize_inputs(weights=weights, m=m)
228228
batch_shape = mean.shape[:-2]
229229
mvn = posterior.distribution
230230
cov = mvn.lazy_covariance_matrix if mvn.islazy else mvn.covariance_matrix

test/acquisition/test_input_constructors.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import math
1515
from collections.abc import Callable
1616
from functools import reduce
17+
18+
from random import randint
1719
from unittest import mock
1820
from unittest.mock import MagicMock
1921

@@ -43,12 +45,14 @@
4345
get_acqf_input_constructor,
4446
get_best_f_analytic,
4547
get_best_f_mc,
48+
optimize_objective,
4649
)
4750
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
4851
from botorch.acquisition.knowledge_gradient import (
4952
qKnowledgeGradient,
5053
qMultiFidelityKnowledgeGradient,
5154
)
55+
5256
from botorch.acquisition.logei import (
5357
qLogExpectedImprovement,
5458
qLogNoisyExpectedImprovement,
@@ -108,6 +112,7 @@
108112
from botorch.models import MultiTaskGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
109113
from botorch.models.deterministic import FixedSingleSampleModel
110114
from botorch.models.model_list_gp_regression import ModelListGP
115+
from botorch.optim.optimize import optimize_acqf
111116
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
112117
from botorch.test_utils.mock import mock_optimize
113118
from botorch.utils.constraints import get_outcome_constraint_transforms
@@ -221,38 +226,73 @@ def test_get_best_f_mc(self) -> None:
221226
best_f_expected = multi_Y.sum(dim=-1).max()
222227
self.assertAllClose(best_f, best_f_expected)
223228

224-
@mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
225-
def test_optimize_objective(self, mock_optimize_acqf):
226-
from botorch.acquisition.input_constructors import optimize_objective
227-
228-
mock_model = self.mock_model
229-
bounds = torch.rand(2, len(self.bounds))
229+
@mock_optimize
230+
def test_optimize_objective(self) -> None:
231+
torch.manual_seed(randint(a=0, b=100))
232+
n = 4
233+
d = 3
234+
x = torch.rand(n, d, dtype=torch.double, device=self.device)
235+
y = torch.rand(n, 1, dtype=torch.double, device=self.device)
236+
model = SingleTaskGP(train_X=x, train_Y=y)
237+
238+
bounds = torch.tensor(
239+
[[0.0, -0.01, -0.02], [1.0, 1.01, 1.02]],
240+
dtype=torch.double,
241+
device=self.device,
242+
)
230243

231244
with self.subTest("scalarObjective_acquisitionFunction"):
232-
optimize_objective(
233-
model=mock_model,
234-
bounds=bounds,
235-
q=1,
236-
acq_function=UpperConfidenceBound(model=mock_model, beta=0.1),
237-
)
245+
acq_function = UpperConfidenceBound(model=model, beta=0.1)
246+
with mock.patch(
247+
"botorch.acquisition.input_constructors.optimize_acqf",
248+
wraps=optimize_acqf,
249+
) as mock_optimize_acqf:
250+
optimize_objective(
251+
model=model,
252+
bounds=bounds,
253+
q=1,
254+
acq_function=acq_function,
255+
)
238256
kwargs = mock_optimize_acqf.call_args[1]
239-
self.assertIsInstance(kwargs["acq_function"], UpperConfidenceBound)
257+
self.assertIs(kwargs["acq_function"], acq_function)
240258

241-
A = torch.rand(1, bounds.shape[-1])
242-
b = torch.zeros([1, 1])
259+
with self.subTest("Passing optimizer"):
260+
# Not testing for a more specific error message because the
261+
# exception comes from Scipy and they might change it
262+
with self.assertRaises(RuntimeWarning):
263+
optimize_objective(
264+
model=model,
265+
bounds=bounds,
266+
q=1,
267+
acq_function=acq_function,
268+
optimizer_options={"method": "throwing darts"},
269+
)
270+
271+
A = torch.rand(1, bounds.shape[-1], dtype=torch.double, device=self.device)
272+
b = torch.zeros([1, 1], dtype=torch.double, device=self.device)
243273
idx = A[0].nonzero(as_tuple=False).squeeze()
244274
inequality_constraints = ((idx, -A[0, idx], -b[0, 0]),)
245275

276+
m = 2
277+
y = torch.rand((n, m), dtype=torch.double, device=self.device)
278+
model = SingleTaskGP(train_X=x, train_Y=y)
279+
246280
with self.subTest("scalarObjective_linearConstraints"):
247-
post_tf = ScalarizedPosteriorTransform(weights=torch.rand(bounds.shape[-1]))
248-
_ = optimize_objective(
249-
model=mock_model,
250-
bounds=bounds,
251-
q=1,
252-
posterior_transform=post_tf,
253-
linear_constraints=(A, b),
254-
fixed_features=None,
281+
post_tf = ScalarizedPosteriorTransform(
282+
weights=torch.rand(m, dtype=torch.double, device=self.device)
255283
)
284+
with mock.patch(
285+
"botorch.acquisition.input_constructors.optimize_acqf",
286+
wraps=optimize_acqf,
287+
) as mock_optimize_acqf:
288+
_ = optimize_objective(
289+
model=model,
290+
bounds=bounds,
291+
q=1,
292+
posterior_transform=post_tf,
293+
linear_constraints=(A, b),
294+
fixed_features=None,
295+
)
256296

257297
kwargs = mock_optimize_acqf.call_args[1]
258298
self.assertIsInstance(kwargs["acq_function"], PosteriorMean)
@@ -264,13 +304,20 @@ def test_optimize_objective(self, mock_optimize_acqf):
264304
self.assertTrue(torch.equal(a, b))
265305

266306
with self.subTest("mcObjective_fixedFeatures"):
267-
_ = optimize_objective(
268-
model=mock_model,
269-
bounds=bounds,
270-
q=1,
271-
objective=LinearMCObjective(weights=torch.rand(bounds.shape[-1])),
272-
fixed_features={0: 0.5},
307+
objective = LinearMCObjective(
308+
weights=torch.rand(m, dtype=torch.double, device=self.device)
273309
)
310+
with mock.patch(
311+
"botorch.acquisition.input_constructors.optimize_acqf",
312+
wraps=optimize_acqf,
313+
) as mock_optimize_acqf:
314+
_ = optimize_objective(
315+
model=model,
316+
bounds=bounds,
317+
q=1,
318+
objective=objective,
319+
fixed_features={0: 0.5},
320+
)
274321

275322
kwargs = mock_optimize_acqf.call_args[1]
276323
self.assertIsInstance(

0 commit comments

Comments
 (0)