Skip to content

Commit 6b8a742

Browse files
sdaulton authored and facebook-github-bot committed
pass X to OutcomeTransform (pytorch#2663)
Summary: This enables using outcome transforms with behavior that depends on X. For example, this enables implementing a stratified standardize transform, where the the standardization is performing for distinct values of an input dimension. Reviewed By: esantorella Differential Revision: D67724473
1 parent a3d47f4 commit 6b8a742

13 files changed

+102
-48
lines changed

botorch/acquisition/analytic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ def _get_noiseless_fantasy_model(
11161116
# Not transforming Yvar because 1e-7 is already close to 0 and it is a
11171117
# relative, not absolute, value.
11181118
Y_fantasized, _ = outcome_transform(
1119-
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
1119+
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1), X=batch_X_observed
11201120
)
11211121
Y_fantasized = Y_fantasized.squeeze(-1)
11221122
input_transform = getattr(model, "input_transform", None)

botorch/models/approximate_gp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def posterior(
172172

173173
posterior = GPyTorchPosterior(distribution=dist)
174174
if hasattr(self, "outcome_transform"):
175-
posterior = self.outcome_transform.untransform_posterior(posterior)
175+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
176176
if posterior_transform is not None:
177177
posterior = posterior_transform(posterior)
178178
return posterior
@@ -397,7 +397,7 @@ def __init__(
397397
UserInputWarning,
398398
stacklevel=3,
399399
)
400-
train_Y, _ = outcome_transform(train_Y)
400+
train_Y, _ = outcome_transform(train_Y, X=transformed_X)
401401
self._validate_tensor_args(X=transformed_X, Y=train_Y)
402402
validate_input_scaling(train_X=transformed_X, train_Y=train_Y)
403403
if train_Y.shape[-1] != num_outputs:

botorch/models/ensemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def posterior(
7979
# `posterior` (as is done in GP models). This is more general since it works
8080
# even if the transform doesn't support `untransform_posterior`.
8181
if hasattr(self, "outcome_transform"):
82-
values, _ = self.outcome_transform.untransform(values)
82+
values, _ = self.outcome_transform.untransform(values, X=X)
8383
if output_indices is not None:
8484
values = values[..., output_indices]
8585
posterior = EnsemblePosterior(values=values)

botorch/models/fully_bayesian.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ def __init__(
373373
X=train_X, input_transform=input_transform
374374
)
375375
if outcome_transform is not None:
376-
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
376+
train_Y, train_Yvar = outcome_transform(
377+
Y=train_Y, Yvar=train_Yvar, X=transformed_X
378+
)
377379
self._validate_tensor_args(X=transformed_X, Y=train_Y)
378380
validate_input_scaling(
379381
train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar

botorch/models/fully_bayesian_multitask.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def __init__(
242242
)
243243
if outcome_transform is not None:
244244
outcome_transform.train() # Ensure we learn parameters here on init
245-
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
245+
train_Y, train_Yvar = outcome_transform(
246+
Y=train_Y, Yvar=train_Yvar, X=transformed_X
247+
)
246248
if train_Yvar is not None: # Clamp after transforming
247249
train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
248250

botorch/models/gp_regression.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def __init__(
160160
X=train_X, input_transform=input_transform
161161
)
162162
if outcome_transform is not None:
163-
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
163+
train_Y, train_Yvar = outcome_transform(
164+
Y=train_Y, Yvar=train_Yvar, X=transformed_X
165+
)
164166
# Validate again after applying the transforms
165167
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
166168
ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)

botorch/models/gpytorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def posterior(
198198
mvn = self.likelihood(mvn, X)
199199
posterior = GPyTorchPosterior(distribution=mvn)
200200
if hasattr(self, "outcome_transform"):
201-
posterior = self.outcome_transform.untransform_posterior(posterior)
201+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
202202
if posterior_transform is not None:
203203
return posterior_transform(posterior)
204204
return posterior
@@ -244,7 +244,7 @@ def condition_on_observations(
244244
# (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
245245
if not isinstance(self, BatchedMultiOutputGPyTorchModel):
246246
# `noise` is assumed to already be outcome-transformed.
247-
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar)
247+
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar, X=X)
248248
# Validate using strict=False, since we cannot tell if Y has an explicit
249249
# output dimension. Do not check shapes when fantasizing as they are
250250
# not expected to match.
@@ -467,7 +467,7 @@ def posterior(
467467

468468
posterior = GPyTorchPosterior(distribution=mvn)
469469
if hasattr(self, "outcome_transform"):
470-
posterior = self.outcome_transform.untransform_posterior(posterior)
470+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
471471
if posterior_transform is not None:
472472
return posterior_transform(posterior)
473473
return posterior
@@ -511,7 +511,7 @@ def condition_on_observations(
511511
if hasattr(self, "outcome_transform"):
512512
# We need to apply transforms before shifting batch indices around.
513513
# `noise` is assumed to already be outcome-transformed.
514-
Y, _ = self.outcome_transform(Y)
514+
Y, _ = self.outcome_transform(Y, X=X)
515515
# Do not check shapes when fantasizing as they are not expected to match.
516516
if fantasize_flag.off():
517517
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
@@ -924,7 +924,7 @@ def posterior(
924924
)
925925
posterior = GPyTorchPosterior(distribution=mtmvn)
926926
if hasattr(self, "outcome_transform"):
927-
posterior = self.outcome_transform.untransform_posterior(posterior)
927+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
928928
if posterior_transform is not None:
929929
return posterior_transform(posterior)
930930
return posterior

botorch/models/higher_order_gp.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _return_to_output_shape(self, tsr: Tensor) -> Tensor:
9191
return out
9292

9393
def forward(
94-
self, Y: Tensor, Yvar: Tensor | None = None
94+
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
9595
) -> tuple[Tensor, Tensor | None]:
9696
Y = self._squeeze_to_single_output(Y)
9797
if Yvar is not None:
@@ -107,21 +107,21 @@ def forward(
107107
return Y_out, Yvar_out
108108

109109
def untransform(
110-
self, Y: Tensor, Yvar: Tensor | None = None
110+
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
111111
) -> tuple[Tensor, Tensor | None]:
112112
Y = self._squeeze_to_single_output(Y)
113113
if Yvar is not None:
114114
Yvar = self._squeeze_to_single_output(Yvar)
115115

116-
Y, Yvar = super().untransform(Y, Yvar)
116+
Y, Yvar = super().untransform(Y=Y, Yvar=Yvar, X=X)
117117

118118
Y = self._return_to_output_shape(Y)
119119
if Yvar is not None:
120120
Yvar = self._return_to_output_shape(Yvar)
121121
return Y, Yvar
122122

123123
def untransform_posterior(
124-
self, posterior: HigherOrderGPPosterior
124+
self, posterior: HigherOrderGPPosterior, X: Tensor | None = None
125125
) -> TransformedPosterior:
126126
# TODO: return a HigherOrderGPPosterior once rescaling constant
127127
# muls * LinearOperators won't force a dense decomposition rather than a
@@ -227,7 +227,7 @@ def __init__(
227227
output_shape=train_Y.shape[-num_output_dims:],
228228
batch_shape=batch_shape,
229229
)
230-
train_Y, _ = outcome_transform(train_Y)
230+
train_Y, _ = outcome_transform(train_Y, X=train_X)
231231

232232
self._aug_batch_shape = batch_shape
233233
self._num_dimensions = num_output_dims + 1
@@ -416,7 +416,7 @@ def condition_on_observations(
416416
"""
417417
if hasattr(self, "outcome_transform"):
418418
# we need to apply transforms before shifting batch indices around
419-
Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
419+
Y, noise = self.outcome_transform(Y=Y, Yvar=noise, X=X)
420420
# Do not check shapes when fantasizing as they are not expected to match.
421421
if fantasize_flag.off():
422422
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
@@ -540,7 +540,7 @@ def posterior(
540540
num_outputs=self._num_outputs,
541541
)
542542
if hasattr(self, "outcome_transform"):
543-
posterior = self.outcome_transform.untransform_posterior(posterior)
543+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
544544
return posterior
545545

546546
def make_posterior_variances(

botorch/models/latent_kronecker_gp.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
self._use_min = use_min
8282

8383
def forward(
84-
self, Y: Tensor, Yvar: Tensor | None = None
84+
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
8585
) -> tuple[Tensor, Tensor | None]:
8686
r"""Standardize outcomes.
8787
@@ -93,6 +93,7 @@ def forward(
9393
Y: A `batch_shape x n x m`-dim tensor of training targets.
9494
Yvar: A `batch_shape x n x m`-dim tensor of observation noises
9595
associated with the training targets (if applicable).
96+
X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
9697
9798
Returns:
9899
A two-tuple with the transformed outcomes:
@@ -240,7 +241,9 @@ def __init__(
240241
outcome_transform = MinMaxStandardize(batch_shape=batch_shape)
241242
if outcome_transform is not None:
242243
# transform outputs once and keep the results
243-
train_Y = outcome_transform(train_Y.unsqueeze(-1))[0].squeeze(-1)
244+
train_Y = outcome_transform(train_Y.unsqueeze(-1), X=transformed_X)[
245+
0
246+
].squeeze(-1)
244247

245248
ExactGP.__init__(
246249
self,
@@ -506,7 +509,7 @@ def _rsample_from_base_samples(
506509
)
507510
# samples.shape = (*sample_shape, *broadcast_shape, n_test_x, n_t)
508511
if hasattr(self, "outcome_transform") and self.outcome_transform is not None:
509-
samples, _ = self.outcome_transform.untransform(samples)
512+
samples, _ = self.outcome_transform.untransform(samples, X=X)
510513
return samples
511514

512515
def condition_on_observations(

botorch/models/model_list_gp_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def condition_on_observations(
117117
else:
118118
noise_i = torch.cat([noise[..., k] for k in range(i, j)], dim=-1)
119119
if hasattr(model, "outcome_transform"):
120-
y_i, noise_i = model.outcome_transform(y_i, noise_i)
120+
y_i, noise_i = model.outcome_transform(y_i, noise_i, X=X_i)
121121
if noise_i is not None:
122122
noise_i = noise_i.squeeze(0)
123123
targets.append(y_i)

botorch/models/multitask.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def __init__(
219219
if outcome_transform == DEFAULT:
220220
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
221221
if outcome_transform is not None:
222-
train_Y, train_Yvar = outcome_transform(Y=train_Y, Yvar=train_Yvar)
222+
train_Y, train_Yvar = outcome_transform(
223+
Y=train_Y, Yvar=train_Yvar, X=transformed_X
224+
)
223225

224226
# squeeze output dim
225227
train_Y = train_Y.squeeze(-1)
@@ -464,7 +466,7 @@ def __init__(
464466
X=train_X, input_transform=input_transform
465467
)
466468
if outcome_transform is not None:
467-
train_Y, _ = outcome_transform(train_Y)
469+
train_Y, _ = outcome_transform(train_Y, X=transformed_X)
468470

469471
self._validate_tensor_args(X=transformed_X, Y=train_Y)
470472
self._num_outputs = train_Y.shape[-1]
@@ -772,7 +774,7 @@ def posterior(
772774
)
773775

774776
if hasattr(self, "outcome_transform"):
775-
posterior = self.outcome_transform.untransform_posterior(posterior)
777+
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
776778
return posterior
777779

778780
def train(self, val=True, *args, **kwargs):

0 commit comments

Comments (0)