|
12 | 12 | from botorch.acquisition.objective import ScalarizedPosteriorTransform
|
13 | 13 | from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
|
14 | 14 | from botorch.exceptions.warnings import OptimizationWarning
|
15 |
| -from botorch.fit import fit_gpytorch_mll |
| 15 | +from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll |
| 16 | +from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP |
16 | 17 | from botorch.models.gp_regression import SingleTaskGP
|
17 | 18 | from botorch.models.model_list_gp_regression import ModelListGP
|
18 | 19 | from botorch.models.multitask import MultiTaskGP
|
@@ -733,3 +734,29 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
|
733 | 734 | self.assertTrue(
|
734 | 735 | torch.equal(fm_i.train_inputs[0][0][-1], X[1 - i])
|
735 | 736 | )
|
| 737 | + |
| 738 | + def test_with_different_batch_shapes(self) -> None: |
| 739 | + # Tests that we can mix single task and SAAS models together. |
| 740 | + tkwargs = {"device": self.device, "dtype": torch.double} |
| 741 | + m1 = SaasFullyBayesianSingleTaskGP( |
| 742 | + train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs) |
| 743 | + ) |
| 744 | + fit_fully_bayesian_model_nuts(m1, warmup_steps=0, num_samples=8, thinning=1) |
| 745 | + m2 = SingleTaskGP( |
| 746 | + train_X=torch.rand(10, 2, **tkwargs), train_Y=torch.rand(10, 1, **tkwargs) |
| 747 | + ) |
| 748 | + m = ModelListGP(m1, m2) |
| 749 | + with self.assertWarnsRegex(UserWarning, "Component models of"): |
| 750 | + self.assertEqual(m.batch_shape, torch.Size([8])) |
| 751 | + # Non-batched evaluation. |
| 752 | + with self.assertWarnsRegex(UserWarning, "Component models of"): |
| 753 | + post = m.posterior(torch.rand(1, 2, **tkwargs)) |
| 754 | + self.assertEqual(post.batch_shape, torch.Size([8])) |
| 755 | + self.assertEqual(post.rsample(torch.Size([2])).shape, torch.Size([2, 8, 1, 2])) |
| 756 | + # Batched evaluation. |
| 757 | + with self.assertWarnsRegex(UserWarning, "Component models of"): |
| 758 | + post = m.posterior(torch.rand(5, 1, 2, **tkwargs)) |
| 759 | + self.assertEqual(post.batch_shape, torch.Size([5, 8])) |
| 760 | + self.assertEqual( |
| 761 | + post.rsample(torch.Size([2])).shape, torch.Size([2, 5, 8, 1, 2]) |
| 762 | + ) |
0 commit comments