Skip to content

Commit 4033488

Browse files
sdaultonfacebook-github-bot
authored andcommitted
update dispatch/storage for new Fully Bayesian models/refactor (facebook#3806)
Summary: X-link: pytorch/botorch#2857 see title Reviewed By: Balandat Differential Revision: D74824675
1 parent 7bc053a commit 4033488

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

ax/generators/torch/botorch_modular/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
)
3030
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
3131
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
32-
from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP
32+
from botorch.models.fully_bayesian import (
33+
AbstractFullyBayesianSingleTaskGP,
34+
FullyBayesianSingleTaskGP,
35+
)
3336
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
3437
from botorch.models.gp_regression import SingleTaskGP
3538
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
@@ -492,10 +495,11 @@ def _fit_botorch_model_gpytorch(
492495
fit_gpytorch_mll(mll)
493496

494497

495-
@fit_botorch_model.register(FullyBayesianSingleTaskGP)
496-
@fit_botorch_model.register(SaasFullyBayesianMultiTaskGP)
498+
@fit_botorch_model.register(
499+
(AbstractFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP)
500+
)
497501
def _fit_botorch_model_fully_bayesian_nuts(
498-
model: FullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
502+
model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
499503
mll_class: type[MarginalLogLikelihood],
500504
mll_options: dict[str, Any] | None = None,
501505
) -> None:

ax/storage/botorch_modular_registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@
5656
AnalyticExpectedUtilityOfBestOption,
5757
qExpectedUtilityOfBestOption,
5858
)
59-
from botorch.models import SaasFullyBayesianSingleTaskGP
6059
from botorch.models.contextual import LCEAGP
61-
from botorch.models.fully_bayesian import FullyBayesianLinearSingleTaskGP
60+
from botorch.models.fully_bayesian import (
61+
FullyBayesianLinearSingleTaskGP,
62+
FullyBayesianSingleTaskGP,
63+
SaasFullyBayesianSingleTaskGP,
64+
)
6265
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
6366

6467
# BoTorch `Model` imports
@@ -124,6 +127,7 @@
124127
SingleTaskMultiFidelityGP: "SingleTaskMultiFidelityGP",
125128
FullyBayesianLinearSingleTaskGP: "FullyBayesianLinearSingleTaskGP",
126129
SaasFullyBayesianSingleTaskGP: "SaasFullyBayesianSingleTaskGP",
130+
FullyBayesianSingleTaskGP: "FullyBayesianSingleTaskGP",
127131
SaasFullyBayesianMultiTaskGP: "SaasFullyBayesianMultiTaskGP",
128132
LCEAGP: "LCEAGP",
129133
}

0 commit comments

Comments
 (0)