
Commit e9ce11f

SaiAakash authored and facebook-github-bot committed
Added posterior_transform to posterior method in ApproximateGPyTorchModel (#2531)
Summary:

## Motivation

This PR fixes #2530. It adds a new `posterior_transform` parameter to the `posterior` method of `ApproximateGPyTorchModel`.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2531

Test Plan: I was able to generate candidate points with the `ExpectedImprovement` acquisition function using a `SingleTaskVariationalGP` trained on 2 output columns.

## Related PRs

NA

Reviewed By: mgarrard

Differential Revision: D62652630

Pulled By: Balandat

fbshipit-source-id: 6870c8a1f47454e70951e7e7eb420cd05a2fb246
1 parent 6ebfa82 commit e9ce11f
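The workflow described in the Test Plan can be sketched roughly as follows. This is a minimal, hypothetical example and not code from the PR: the toy data, scalarization weights, and optimizer settings are illustrative assumptions.

```python
import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.fit import fit_gpytorch_mll
from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import VariationalELBO

# Toy training data with two output columns (illustrative only).
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.stack([train_X.sum(-1), train_X.prod(-1)], dim=-1)

# Variational GP trained on both outputs.
model = SingleTaskVariationalGP(train_X, train_Y, inducing_points=train_X)
mll = VariationalELBO(model.likelihood, model.model, num_data=train_Y.shape[0])
fit_gpytorch_mll(mll)

# Scalarize the two outputs so the analytic EI sees a single objective;
# the transform is forwarded to `posterior` via the acquisition function.
weights = torch.tensor([1.0, 1.0], dtype=torch.double)
posterior_transform = ScalarizedPosteriorTransform(weights=weights)
acqf = ExpectedImprovement(
    model=model,
    best_f=(train_Y @ weights).max(),
    posterior_transform=posterior_transform,
)

# Optimize the acquisition function over the unit cube to get one candidate.
bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(torch.double)
candidate, acq_value = optimize_acqf(
    acqf, bounds=bounds, q=1, num_restarts=4, raw_samples=64
)
```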

2 files changed: +24 -3 lines changed

botorch/models/approximate_gp.py
Lines changed: 13 additions & 3 deletions

@@ -31,10 +31,10 @@
 
 import copy
 import warnings
-
 from typing import Optional, Union
 
 import torch
+from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions.warnings import UserInputWarning
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.transforms.input import InputTransform
@@ -146,8 +146,16 @@ def train(self, mode: bool = True) -> Self:
         return Module.train(self, mode=mode)
 
     def posterior(
-        self, X, output_indices=None, observation_noise=False, *args, **kwargs
+        self,
+        X,
+        output_indices: Optional[list[int]] = None,
+        observation_noise: bool = False,
+        posterior_transform: Optional[PosteriorTransform] = None,
     ) -> GPyTorchPosterior:
+        if output_indices is not None:
+            raise NotImplementedError(  # pragma: no cover
+                f"{self.__class__.__name__}.posterior does not support output indices."
+            )
         self.eval()  # make sure model is in eval mode
 
         # input transforms are applied at `posterior` in `eval` mode, and at
@@ -161,11 +169,13 @@ def posterior(
             X = X.unsqueeze(-3).repeat(*[1] * (X_ndim - 2), self.num_outputs, 1, 1)
         dist = self.model(X)
         if observation_noise:
-            dist = self.likelihood(dist, *args, **kwargs)
+            dist = self.likelihood(dist)
 
         posterior = GPyTorchPosterior(distribution=dist)
         if hasattr(self, "outcome_transform"):
             posterior = self.outcome_transform.untransform_posterior(posterior)
+        if posterior_transform is not None:
+            posterior = posterior_transform(posterior)
         return posterior
 
     def forward(self, X) -> MultivariateNormal:
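With this change, a posterior transform can also be passed directly to `posterior` on a variational model. A minimal sketch under assumed toy data (the untrained two-output model and test points are purely illustrative; the expected single-output shape follows the new unit test below):

```python
import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.approximate_gp import SingleTaskVariationalGP

# Toy two-output model; training is skipped since we only exercise `posterior`.
train_X = torch.rand(16, 2, dtype=torch.double)
train_Y = torch.rand(16, 2, dtype=torch.double)
model = SingleTaskVariationalGP(train_X, train_Y, inducing_points=train_X)

test_X = torch.rand(5, 2, dtype=torch.double)
# Joint posterior over both outputs.
posterior = model.posterior(test_X)
# Scalarized posterior: the transform is applied after any outcome transform
# has been untransformed, so the mean should have a single output column.
posterior_t = model.posterior(
    test_X,
    posterior_transform=ScalarizedPosteriorTransform(
        weights=torch.tensor([1.0, 1.0], dtype=torch.double)
    ),
)
print(posterior.mean.shape, posterior_t.mean.shape)
```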

test/models/test_approximate_gp.py
Lines changed: 11 additions & 0 deletions

@@ -8,6 +8,7 @@
 import warnings
 
 import torch
+from botorch.acquisition.objective import ScalarizedPosteriorTransform
 from botorch.exceptions.warnings import UserInputWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models.approximate_gp import (
@@ -103,6 +104,16 @@ def test_posterior(self):
         # test batch_shape property
         self.assertEqual(model.batch_shape, tx.shape[:-2])
 
+        # Test that checks if posterior_transform is correctly applied
+        [tx1, ty1, test1] = all_tests["non_batched_mo"]
+        model1 = SingleTaskVariationalGP(tx1, ty1, inducing_points=tx1)
+        posterior_transform = ScalarizedPosteriorTransform(
+            weights=torch.tensor([1.0, 1.0], device=self.device)
+        )
+        posterior1 = model1.posterior(test1, posterior_transform=posterior_transform)
+        self.assertIsInstance(posterior1, GPyTorchPosterior)
+        self.assertEqual(posterior1.mean.shape[1], 1)
+
     def test_variational_setUp(self):
         for dtype in [torch.float, torch.double]:
             train_X = torch.rand(10, 1, device=self.device, dtype=dtype)
