Skip to content

Commit be3da75

Browse files
sdaultonfacebook-github-bot
authored andcommitted
update dispatch/storage for new Fully Bayesian models/refactor (facebook#3806)
Summary: X-link: pytorch/botorch#2857 see title Reviewed By: Balandat Differential Revision: D74824675
1 parent e2cc345 commit be3da75

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

ax/generators/torch/botorch_modular/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
)
3333
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
3434
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
35-
from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
35+
from botorch.models.fully_bayesian import (
36+
AbstractFullyBayesianSingleTaskGP,
37+
FullyBayesianSingleTaskGP,
38+
)
3639
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
3740
from botorch.models.gp_regression import SingleTaskGP
3841
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
@@ -532,10 +535,11 @@ def _fit_botorch_model_gpytorch(
532535
fit_gpytorch_mll(mll)
533536

534537

535-
@fit_botorch_model.register(FullyBayesianSingleTaskGP)
536-
@fit_botorch_model.register(SaasFullyBayesianMultiTaskGP)
538+
@fit_botorch_model.register(
539+
(AbstractFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP)
540+
)
537541
def _fit_botorch_model_fully_bayesian_nuts(
538-
model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
542+
model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
539543
mll_class: type[MarginalLogLikelihood],
540544
mll_options: dict[str, Any] | None = None,
541545
) -> None:

ax/storage/botorch_modular_registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@
5656
AnalyticExpectedUtilityOfBestOption,
5757
qExpectedUtilityOfBestOption,
5858
)
59-
from botorch.models import SaasFullyBayesianSingleTaskGP
6059
from botorch.models.contextual import LCEAGP
61-
from botorch.models.fully_bayesian import FullyBayesianLinearSingleTaskGP
60+
from botorch.models.fully_bayesian import (
61+
FullyBayesianLinearSingleTaskGP,
62+
FullyBayesianSingleTaskGP,
63+
SaasFullyBayesianSingleTaskGP,
64+
)
6265
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
6366

6467
# BoTorch `Model` imports
@@ -124,6 +127,7 @@
124127
SingleTaskMultiFidelityGP: "SingleTaskMultiFidelityGP",
125128
FullyBayesianLinearSingleTaskGP: "FullyBayesianLinearSingleTaskGP",
126129
SaasFullyBayesianSingleTaskGP: "SaasFullyBayesianSingleTaskGP",
130+
FullyBayesianSingleTaskGP: "FullyBayesianSingleTaskGP",
127131
SaasFullyBayesianMultiTaskGP: "SaasFullyBayesianMultiTaskGP",
128132
LCEAGP: "LCEAGP",
129133
}

0 commit comments

Comments
 (0)