Skip to content

Commit d247a33

Browse files
sdaulton authored
and facebook-github-bot committed
fix bug where input transforms are not applied in fully Bayesian models in train mode (#2859)
Summary: Pull Request resolved: #2859 This fixes a bug where input transforms were not applied to fully Bayesian GPs in training mode. This only affects computing MLL, AIC, and BIC (which previously were computed without applying normalization/warping) for fully Bayesian GPs. We don't evaluate fully Bayesian models in `train` mode. Reviewed By: saitcakmak Differential Revision: D74827275 fbshipit-source-id: eac39e2dc7533581aff3206b0eaf642e01fc7e8c
1 parent c24e534 commit d247a33

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

botorch/models/fully_bayesian.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,8 @@ def forward(self, X: Tensor) -> MultivariateNormal:
794794
rest of this method will not run.
795795
"""
796796
self._check_if_fitted()
797+
if self.training:
798+
X = self.transform_inputs(X=X)
797799
mean_x = self.mean_module(X)
798800
covar_x = self.covar_module(X)
799801
return MultivariateNormal(mean_x, covar_x)

test/models/test_fully_bayesian.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,36 @@ def test_deprecated_posterior(self) -> None:
956956
posterior = FullyBayesianPosterior(distribution=mvn)
957957
self.assertIsInstance(posterior, GaussianMixturePosterior)
958958

959+
def test_predict_in_train_mode(self) -> None:
960+
torch.manual_seed(16)
961+
for infer_noise, dtype in itertools.product(
962+
[True, False], [torch.float, torch.double]
963+
):
964+
tkwargs = {"device": self.device, "dtype": dtype}
965+
train_X, train_Y, train_Yvar, _ = self._get_data_and_model(
966+
infer_noise=infer_noise, **tkwargs
967+
)
968+
# Fit a model and check that the hyperparameters have the correct shape
969+
model = self.model_cls(
970+
train_X=train_X,
971+
train_Y=train_Y,
972+
train_Yvar=train_Yvar,
973+
input_transform=Normalize(d=train_X.shape[-1]),
974+
outcome_transform=Standardize(m=1),
975+
**self.model_kwargs,
976+
)
977+
fit_fully_bayesian_model_nuts(
978+
model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
979+
)
980+
# check that input transforms are called when calling forward in train mode
981+
model.train(reset=False)
982+
with mock.patch.object(
983+
model.input_transform, "forward", wraps=model.input_transform.forward
984+
) as mock_input_tf:
985+
with torch.no_grad():
986+
model(*model.train_inputs)
987+
mock_input_tf.assert_called_once()
988+
959989

960990
class TestFullyBayesianSingleTaskGP(TestSaasFullyBayesianSingleTaskGP):
961991
model_cls: type[FullyBayesianSingleTaskGP] = FullyBayesianSingleTaskGP

0 commit comments

Comments
 (0)