|
31 | 31 | """
|
32 | 32 |
|
33 | 33 | import math
|
34 |
| -from abc import abstractmethod |
| 34 | +from abc import ABC, abstractmethod |
35 | 35 | from collections.abc import Mapping
|
36 | 36 | from typing import Any
|
37 | 37 |
|
@@ -311,14 +311,13 @@ def load_mcmc_samples(
|
311 | 311 | return mean_module, covar_module, likelihood
|
312 | 312 |
|
313 | 313 |
|
314 |
| -class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel): |
315 |
| - r"""A fully Bayesian single-task GP model with the SAAS prior. |
| 314 | +class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC): |
| 315 | + r"""An abstract fully Bayesian single-task GP model. |
316 | 316 |
|
317 | 317 | This model assumes that the inputs have been normalized to [0, 1]^d and that
|
318 | 318 | the output has been standardized to have zero mean and unit variance. You can
|
319 | 319 | either normalize and standardize the data before constructing the model or use
|
320 |
| - an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_ |
321 |
| - with a Matern-5/2 kernel is used by default. |
| 320 | + an `input_transform` and `outcome_transform`. |
322 | 321 |
|
323 | 322 | You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
|
324 | 323 | isn't compatible with `fit_gpytorch_mll`.
|
@@ -412,17 +411,9 @@ def _check_if_fitted(self):
|
412 | 411 | )
|
413 | 412 |
|
414 | 413 | @property
|
415 |
| - def median_lengthscale(self) -> Tensor: |
416 |
| - r"""Median lengthscales across the MCMC samples.""" |
417 |
| - self._check_if_fitted() |
418 |
| - lengthscale = self.covar_module.base_kernel.lengthscale.clone() |
419 |
| - return lengthscale.median(0).values.squeeze(0) |
420 |
| - |
421 |
| - @property |
| 414 | + @abstractmethod |
422 | 415 | def num_mcmc_samples(self) -> int:
|
423 | 416 | r"""Number of MCMC samples in the model."""
|
424 |
| - self._check_if_fitted() |
425 |
| - return len(self.covar_module.outputscale) |
426 | 417 |
|
427 | 418 | @property
|
428 | 419 | def batch_shape(self) -> torch.Size:
|
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
|
459 | 450 | self.likelihood,
|
460 | 451 | ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
|
461 | 452 |
|
462 |
| - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): |
463 |
| - r"""Custom logic for loading the state dict. |
464 |
| -
|
465 |
| - The standard approach of calling `load_state_dict` currently doesn't play well |
466 |
| - with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module` |
467 |
| - and `likelihood` aren't initialized until the model has been fitted. The reason |
468 |
| - for this is that we don't know the number of MCMC samples until NUTS is called. |
469 |
| - Given the state dict, we can initialize a new model with some dummy samples and |
470 |
| - then load the state dict into this model. This currently only works for a |
471 |
| - `SaasPyroModel` and supporting more Pyro models likely requires moving the model |
472 |
| - construction logic into the Pyro model itself. |
473 |
| - """ |
474 |
| - |
475 |
| - if not isinstance(self.pyro_model, SaasPyroModel): |
476 |
| - raise NotImplementedError("load_state_dict only works for SaasPyroModel") |
477 |
| - raw_mean = state_dict["mean_module.raw_constant"] |
478 |
| - num_mcmc_samples = len(raw_mean) |
479 |
| - dim = self.pyro_model.train_X.shape[-1] |
480 |
| - tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype} |
481 |
| - # Load some dummy samples |
482 |
| - mcmc_samples = { |
483 |
| - "mean": torch.ones(num_mcmc_samples, **tkwargs), |
484 |
| - "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs), |
485 |
| - "outputscale": torch.ones(num_mcmc_samples, **tkwargs), |
486 |
| - } |
487 |
| - if self.pyro_model.train_Yvar is None: |
488 |
| - mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs) |
489 |
| - ( |
490 |
| - self.mean_module, |
491 |
| - self.covar_module, |
492 |
| - self.likelihood, |
493 |
| - ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples) |
494 |
| - # Load the actual samples from the state dict |
495 |
| - super().load_state_dict(state_dict=state_dict, strict=strict) |
496 |
| - |
497 | 453 | def forward(self, X: Tensor) -> MultivariateNormal:
|
498 | 454 | """
|
499 | 455 | Unlike in other classes' `forward` methods, there is no `if self.training`
|
@@ -579,3 +535,70 @@ def condition_on_observations(
|
579 | 535 | X = X.repeat(*(Y.shape[:-2] + (1, 1)))
|
580 | 536 |
|
581 | 537 | return super().condition_on_observations(X, Y, **kwargs)
|
| 538 | + |
| 539 | + |
class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
    r"""A fully Bayesian single-task GP model with the SAAS prior.

    Inputs are assumed to be normalized to the unit cube [0, 1]^d and outputs
    standardized to zero mean and unit variance; either preprocess the data
    accordingly or supply an `input_transform` / `outcome_transform`. By
    default the SAAS model [Eriksson2021saasbo]_ is paired with a Matern-5/2
    kernel.

    Fit this model with `fit_fully_bayesian_model_nuts`; it is not compatible
    with `fit_gpytorch_mll`.

    Example:
        >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
        >>> fit_fully_bayesian_model_nuts(saas_gp)
        >>> posterior = saas_gp.posterior(test_X)
    """

    @property
    def num_mcmc_samples(self) -> int:
        r"""Number of MCMC samples in the model."""
        self._check_if_fitted()
        # The fitted kernel stores one outputscale entry per posterior sample.
        return len(self.covar_module.outputscale)

    @property
    def median_lengthscale(self) -> Tensor:
        r"""Median lengthscales across the MCMC samples."""
        self._check_if_fitted()
        samples = self.covar_module.base_kernel.lengthscale.clone()
        # Reduce over the MCMC-sample dimension (dim 0).
        return samples.median(0).values.squeeze(0)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        r"""Load a state dict into the model.

        Plain `load_state_dict` fails on this class because `mean_module`,
        `covar_module` and `likelihood` only exist after fitting — the number
        of MCMC samples is unknown until NUTS has been run. We therefore infer
        that number from the state dict itself, build the modules from
        placeholder samples of the right shape, and then load the real
        parameters on top. This currently only works for a `SaasPyroModel`;
        supporting other Pyro models likely requires moving the module
        construction logic into the Pyro model itself.
        """

        if not isinstance(self.pyro_model, SaasPyroModel):
            raise NotImplementedError("load_state_dict only works for SaasPyroModel")
        raw_mean = state_dict["mean_module.raw_constant"]
        n_samples = len(raw_mean)
        d = self.pyro_model.train_X.shape[-1]
        tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
        # Placeholder samples: only their shapes matter, the values are
        # overwritten below when the actual state dict is loaded.
        dummy_samples = {
            "mean": torch.ones(n_samples, **tkwargs),
            "lengthscale": torch.ones(n_samples, d, **tkwargs),
            "outputscale": torch.ones(n_samples, **tkwargs),
        }
        if self.pyro_model.train_Yvar is None:
            # Inferred (rather than fixed) noise also needs a sampled site.
            dummy_samples["noise"] = torch.ones(n_samples, **tkwargs)
        modules = self.pyro_model.load_mcmc_samples(mcmc_samples=dummy_samples)
        self.mean_module, self.covar_module, self.likelihood = modules
        # Now overwrite the placeholders with the actual parameters.
        super().load_state_dict(state_dict=state_dict, strict=strict)
0 commit comments