
Commit b07772d

sdaulton authored and facebook-github-bot committed
add abstract fully bayesian GP (#2696)
Summary:
Pull Request resolved: #2696

See title. This enables supporting different fully Bayesian models.

Reviewed By: saitcakmak

Differential Revision: D68529434

fbshipit-source-id: b739112924e3eeb061794c001c6c99c45bfb07fe
1 parent 0df5521 commit b07772d

File tree

2 files changed: +74 additions, -51 deletions


botorch/acquisition/joint_entropy_search.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from botorch import settings
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
 from botorch.acquisition.objective import PosteriorTransform
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
 from botorch.models.model import Model
 from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL

@@ -124,7 +124,7 @@ def __init__(
         # and the optimal outputs have shapes num_optima x [num_models if FB] x 1 x 1
         # The third dimension equaling 1 is required to get one optimum per model,
         # which raises a BotorchTensorDimensionWarning.
-        if isinstance(model, SaasFullyBayesianSingleTaskGP):
+        if isinstance(model, FullyBayesianSingleTaskGP):
            raise NotImplementedError(FULLY_BAYESIAN_ERROR_MSG)
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore")
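
The acquisition-side change is purely a broadened type check: since `SaasFullyBayesianSingleTaskGP` now subclasses the abstract base (see the second file below), matching on `FullyBayesianSingleTaskGP` rejects every fully Bayesian variant at once. A small sketch of the pattern; `_reject_fully_bayesian` is a hypothetical helper for illustration, not part of BoTorch:

    from botorch.models.fully_bayesian import (
        FullyBayesianSingleTaskGP,
        SaasFullyBayesianSingleTaskGP,
    )
    from botorch.models.model import Model


    def _reject_fully_bayesian(model: Model) -> None:
        # SaasFullyBayesianSingleTaskGP subclasses FullyBayesianSingleTaskGP,
        # so this one check also covers any future fully Bayesian variant.
        if isinstance(model, FullyBayesianSingleTaskGP):
            raise NotImplementedError("Fully Bayesian models are not supported.")


    assert issubclass(SaasFullyBayesianSingleTaskGP, FullyBayesianSingleTaskGP)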

botorch/models/fully_bayesian.py

Lines changed: 72 additions & 49 deletions
@@ -31,7 +31,7 @@
 """

 import math
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import Any

@@ -311,14 +311,13 @@ def load_mcmc_samples(
         return mean_module, covar_module, likelihood


-class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
-    r"""A fully Bayesian single-task GP model with the SAAS prior.
+class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC):
+    r"""An abstract fully Bayesian single-task GP model.

     This model assumes that the inputs have been normalized to [0, 1]^d and that
     the output has been standardized to have zero mean and unit variance. You can
     either normalize and standardize the data before constructing the model or use
-    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
-    with a Matern-5/2 kernel is used by default.
+    an `input_transform` and `outcome_transform`.

     You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
     isn't compatible with `fit_gpytorch_mll`.
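
The abstract base keeps the docstring's normalization assumptions, so data can be rescaled via transforms instead of by hand. A brief sketch of the transform route (the tensor shapes are illustrative):

    import torch
    from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
    from botorch.models.transforms.input import Normalize
    from botorch.models.transforms.outcome import Standardize

    raw_X = torch.rand(20, 4, dtype=torch.float64) * 10.0  # not in [0, 1]^d
    raw_Y = 5.0 + 2.0 * torch.randn(20, 1, dtype=torch.float64)  # not standardized

    # Let the model normalize inputs and standardize outcomes internally.
    gp = SaasFullyBayesianSingleTaskGP(
        raw_X,
        raw_Y,
        input_transform=Normalize(d=raw_X.shape[-1]),
        outcome_transform=Standardize(m=1),
    )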
@@ -412,17 +411,9 @@ def _check_if_fitted(self):
         )

     @property
-    def median_lengthscale(self) -> Tensor:
-        r"""Median lengthscales across the MCMC samples."""
-        self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
-        return lengthscale.median(0).values.squeeze(0)
-
-    @property
+    @abstractmethod
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
-        self._check_if_fitted()
-        return len(self.covar_module.outputscale)

     @property
     def batch_shape(self) -> torch.Size:
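
With `num_mcmc_samples` now abstract, each concrete model defines how its posterior samples are counted. A minimal sketch of a hypothetical subclass (the class name and the choice of `mean_module.raw_constant` as the batched hyperparameter are illustrative; a real subclass would also pick a Pyro model):

    from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP


    class MyFullyBayesianGP(FullyBayesianSingleTaskGP):
        # Hypothetical subclass: only the abstract property must be provided.
        @property
        def num_mcmc_samples(self) -> int:
            r"""Number of MCMC samples in the model."""
            self._check_if_fitted()
            # After load_mcmc_samples, modules are batched with one
            # hyperparameter value per MCMC sample.
            return len(self.mean_module.raw_constant)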
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.likelihood,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

-    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
-        r"""Custom logic for loading the state dict.
-
-        The standard approach of calling `load_state_dict` currently doesn't play well
-        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
-        and `likelihood` aren't initialized until the model has been fitted. The reason
-        for this is that we don't know the number of MCMC samples until NUTS is called.
-        Given the state dict, we can initialize a new model with some dummy samples and
-        then load the state dict into this model. This currently only works for a
-        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
-        construction logic into the Pyro model itself.
-        """
-
-        if not isinstance(self.pyro_model, SaasPyroModel):
-            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
-        raw_mean = state_dict["mean_module.raw_constant"]
-        num_mcmc_samples = len(raw_mean)
-        dim = self.pyro_model.train_X.shape[-1]
-        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
-        # Load some dummy samples
-        mcmc_samples = {
-            "mean": torch.ones(num_mcmc_samples, **tkwargs),
-            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
-            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
-        }
-        if self.pyro_model.train_Yvar is None:
-            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
-        (
-            self.mean_module,
-            self.covar_module,
-            self.likelihood,
-        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
-        # Load the actual samples from the state dict
-        super().load_state_dict(state_dict=state_dict, strict=strict)
-
     def forward(self, X: Tensor) -> MultivariateNormal:
         """
         Unlike in other classes' `forward` methods, there is no `if self.training`
@@ -579,3 +535,70 @@ def condition_on_observations(
             X = X.repeat(*(Y.shape[:-2] + (1, 1)))

         return super().condition_on_observations(X, Y, **kwargs)
+
+
+class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
+    r"""A fully Bayesian single-task GP model with the SAAS prior.
+
+    This model assumes that the inputs have been normalized to [0, 1]^d and that
+    the output has been standardized to have zero mean and unit variance. You can
+    either normalize and standardize the data before constructing the model or use
+    an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_
+    with a Matern-5/2 kernel is used by default.
+
+    You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
+    isn't compatible with `fit_gpytorch_mll`.
+
+    Example:
+        >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
+        >>> fit_fully_bayesian_model_nuts(saas_gp)
+        >>> posterior = saas_gp.posterior(test_X)
+    """
+
+    @property
+    def num_mcmc_samples(self) -> int:
+        r"""Number of MCMC samples in the model."""
+        self._check_if_fitted()
+        return len(self.covar_module.outputscale)
+
+    @property
+    def median_lengthscale(self) -> Tensor:
+        r"""Median lengthscales across the MCMC samples."""
+        self._check_if_fitted()
+        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        return lengthscale.median(0).values.squeeze(0)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        r"""Custom logic for loading the state dict.
+
+        The standard approach of calling `load_state_dict` currently doesn't play well
+        with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
+        and `likelihood` aren't initialized until the model has been fitted. The reason
+        for this is that we don't know the number of MCMC samples until NUTS is called.
+        Given the state dict, we can initialize a new model with some dummy samples and
+        then load the state dict into this model. This currently only works for a
+        `SaasPyroModel` and supporting more Pyro models likely requires moving the model
+        construction logic into the Pyro model itself.
+        """
+
+        if not isinstance(self.pyro_model, SaasPyroModel):
+            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
+        raw_mean = state_dict["mean_module.raw_constant"]
+        num_mcmc_samples = len(raw_mean)
+        dim = self.pyro_model.train_X.shape[-1]
+        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
+        # Load some dummy samples
+        mcmc_samples = {
+            "mean": torch.ones(num_mcmc_samples, **tkwargs),
+            "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
+            "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
+        }
+        if self.pyro_model.train_Yvar is None:
+            mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
+        (
+            self.mean_module,
+            self.covar_module,
+            self.likelihood,
+        ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
+        # Load the actual samples from the state dict
+        super().load_state_dict(state_dict=state_dict, strict=strict)
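
Taken together, a usage sketch based on the class docstring's example; the data shapes and the reduced NUTS settings (`warmup_steps`, `num_samples`, `thinning`) are illustrative choices to keep it quick:

    import torch
    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

    train_X = torch.rand(20, 4, dtype=torch.float64)   # inputs normalized to [0, 1]^d
    train_Y = torch.randn(20, 1, dtype=torch.float64)  # standardized outcomes

    saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
    # Fit with NUTS; this model is not compatible with fit_gpytorch_mll.
    fit_fully_bayesian_model_nuts(
        saas_gp, warmup_steps=32, num_samples=32, thinning=2, disable_progbar=True
    )
    posterior = saas_gp.posterior(torch.rand(5, 4, dtype=torch.float64))
    print(saas_gp.num_mcmc_samples, saas_gp.median_lengthscale)

    # The custom load_state_dict restores into an unfitted model: dummy MCMC
    # samples initialize the modules, then the saved values overwrite them.
    new_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
    new_gp.load_state_dict(saas_gp.state_dict())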
