Commit 831ea5d

sdaulton authored and facebook-github-bot committed
StratifiedStandardize OutcomeTransform (#2671)
Summary: Pull Request resolved: #2671. This adds a `StratifiedStandardize` outcome transform, which applies stratified standardization at the model level. That makes it possible to select between a single-task and a multi-task model in Ax while using the appropriate transform for each: e.g., one could specify ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, or 2) `MultiTaskGP` + `StratifiedStandardize`.

Reviewed By: saitcakmak

Differential Revision: D67728920

fbshipit-source-id: ad6ee2bbed3e484288e94dcfb7b1555fbd4395e4
1 parent be8ec7b commit 831ea5d
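
To make the two configurations concrete, here is a minimal usage sketch (the training tensors and the task-column convention are hypothetical, and it assumes the model forwards `X` to the outcome transform, which is the model-level integration this diff enables):

import torch

from botorch.models import MultiTaskGP, SingleTaskGP
from botorch.models.transforms.outcome import Standardize, StratifiedStandardize

# Hypothetical data: the last column of train_X is the task indicator.
train_X = torch.rand(20, 3, dtype=torch.double)
train_X[:, -1] = torch.randint(0, 2, (20,), dtype=torch.double)
train_Y = torch.randn(20, 1, dtype=torch.double)

# 1) Single-task model + Standardize (task column dropped).
single = SingleTaskGP(
    train_X[..., :-1], train_Y, outcome_transform=Standardize(m=1)
)

# 2) Multi-task model + StratifiedStandardize, standardizing per task.
multi = MultiTaskGP(
    train_X,
    train_Y,
    task_feature=-1,
    outcome_transform=StratifiedStandardize(
        task_values=torch.tensor([0, 1]),
        stratification_idx=-1,
    ),
)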

File tree

5 files changed: +407 −58 lines changed

botorch/models/multitask.py

Lines changed: 1 addition & 34 deletions
@@ -39,6 +39,7 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.models.utils.gpytorch_modules import (
     get_covar_module_with_dim_scaled_prior,
     get_gaussian_likelihood_with_lognormal_prior,
@@ -82,40 +83,6 @@
 from torch import Tensor
 
 
-def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
-    """Construct a mapping of discrete task values to contiguous int-valued floats.
-
-    Args:
-        task_values: A sorted long-valued tensor of task values.
-        dtype: The dtype of the model inputs (e.g. `X`), which the new
-            task values should have mapped to (e.g. float, double).
-
-    Returns:
-        A tensor of shape `task_values.max() + 1` that maps task values
-        to new task values. The indexing operation `mapper[task_value]`
-        will produce a tensor of new task values, of the same shape as
-        the original. The elements of the `mapper` tensor that do not
-        appear in the original `task_values` are mapped to `nan`. The
-        return value will be `None` when the task values are contiguous
-        integers starting from zero.
-    """
-    task_range = torch.arange(
-        len(task_values), dtype=task_values.dtype, device=task_values.device
-    )
-    mapper = None
-    if not torch.equal(task_values, task_range):
-        # Create a tensor that maps task values to new task values.
-        # The number of tasks should be small, so this should be quite efficient.
-        mapper = torch.full(
-            (int(task_values.max().item()) + 1,),
-            float("nan"),
-            dtype=dtype,
-            device=task_values.device,
-        )
-        mapper[task_values] = task_range.to(dtype=dtype)
-    return mapper
-
-
 class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
     r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
     kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
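
`get_task_value_remapping` itself is unchanged; it moves to `botorch.models.utils.assorted` so that the outcome transforms can reuse it. As a quick illustration of its contract (values made up), non-contiguous task ids get a `nan`-padded lookup table, while contiguous ids starting at zero yield `None`:

import torch

from botorch.models.utils.assorted import get_task_value_remapping

task_values = torch.tensor([0, 2, 5])  # sorted, non-contiguous task ids
mapper = get_task_value_remapping(task_values, dtype=torch.double)
# mapper is tensor([0., nan, 1., nan, nan, 2.]) of shape task_values.max() + 1
new_values = mapper[task_values]  # tensor([0., 1., 2.])

# Contiguous ids starting at zero need no remapping:
assert get_task_value_remapping(torch.tensor([0, 1, 2]), torch.double) is None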

botorch/models/transforms/outcome.py

Lines changed: 250 additions & 24 deletions
@@ -27,9 +27,11 @@
 
 import torch
 from botorch.models.transforms.utils import (
+    nanstd,
     norm_to_lognorm_mean,
     norm_to_lognorm_variance,
 )
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
 from botorch.utils.transforms import normalize_indices
 from linear_operator.operators import CholLinearOperator, DiagLinearOperator
@@ -259,6 +261,46 @@ def __init__(
         self._batch_shape = batch_shape
         self._min_stdv = min_stdv
 
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+                This parameter is not used by this method.
+
+        Returns:
+            A three-tuple with the means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        return self.means, self.stdvs, self._stdvs_sq
+
+    def _validate_training_inputs(self, Y: Tensor, Yvar: Tensor | None = None) -> None:
+        """Validate training inputs.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of observation noises.
+        """
+        if Y.shape[:-2] != self._batch_shape:
+            raise RuntimeError(
+                f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
+                f"the `batch_shape` argument to `{self.__class__.__name__}`, but got "
+                f"Y.shape[:-2]={Y.shape[:-2]}."
+            )
+        elif Y.shape[-2] < 1:
+            raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+        elif Y.size(-1) != self._m:
+            raise RuntimeError(
+                f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
+                f"{self._m}."
+            )
+
     def forward(
         self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
     ) -> tuple[Tensor, Tensor | None]:
@@ -283,21 +325,8 @@ def forward(
             - The transformed observation noise (if applicable).
         """
         if self.training:
-            if Y.shape[:-2] != self._batch_shape:
-                raise RuntimeError(
-                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
-                    "the `batch_shape` argument to `Standardize`, but got "
-                    f"Y.shape[:-2]={Y.shape[:-2]}."
-                )
-            if Y.size(-1) != self._m:
-                raise RuntimeError(
-                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
-                    f"{self._m}."
-                )
-            if Y.shape[-2] < 1:
-                raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
-
-            elif Y.shape[-2] == 1:
+            self._validate_training_inputs(Y=Y, Yvar=Yvar)
+            if Y.shape[-2] == 1:
                 stdvs = torch.ones(
                     (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
                 )
@@ -313,9 +342,12 @@ def forward(
             self.stdvs = stdvs
             self._stdvs_sq = stdvs.pow(2)
             self._is_trained = torch.tensor(True)
-
-        Y_tf = (Y - self.means) / self.stdvs
-        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_tf = (Y - means) / stdvs
+        Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
         return Y_tf, Yvar_tf
 
     def subset_output(self, idcs: list[int]) -> OutcomeTransform:
@@ -376,9 +408,12 @@ def untransform(
                 "(e.g. `transform(Y)`) before calling `untransform`, since "
                 "means and standard deviations need to be computed."
             )
-
-        Y_utf = self.means + self.stdvs * Y
-        Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_utf = means + stdvs * Y
+        Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
         return Y_utf, Yvar_utf
 
     @property
@@ -433,8 +468,9 @@ def untransform_posterior(
         )
         # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
         mvn = posterior.distribution
-        offset = self.means
-        scale_fac = self.stdvs
+        offset, scale_fac, _ = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=False
+        )
         if not posterior._is_mt:
             mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
             scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
@@ -449,7 +485,6 @@ def untransform_posterior(
 
         if (
             not mvn.islazy
-            # TODO: Figure out attribute naming weirdness here
             or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
         ):
             # if already computed, we can save a lot of time using scale_tril
@@ -465,6 +500,197 @@ def untransform_posterior(
         return GPyTorchPosterior(mvn_tf)
 
 
+class StratifiedStandardize(Standardize):
+    r"""Standardize outcomes (zero mean, unit variance) along a stratification dim.
+
+    This module is stateful: If in train mode, calling forward updates the
+    module state (i.e. the mean/std normalizing constants). If in eval mode,
+    calling forward simply applies the standardization using the current module
+    state.
+    """
+
+    def __init__(
+        self,
+        task_values: Tensor,
+        stratification_idx: int,
+        batch_shape: torch.Size = torch.Size(),  # noqa: B008
+        min_stdv: float = 1e-8,
+        # dtype: torch.dtype = torch.double,
+    ) -> None:
+        r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
+
+        Note: This currently only supports single-output models
+        (including multi-task models that have a single output).
+
+        Args:
+            task_values: `t`-dim tensor of task values.
+            stratification_idx: The index of the stratification dimension.
+            batch_shape: The batch_shape of the training targets.
+            min_stdv: The minimum standard deviation for which to perform
+                standardization (if lower, only de-mean the data).
+        """
+        OutcomeTransform.__init__(self)
+        self._stratification_idx = stratification_idx
+        task_values = task_values.unique(sorted=True)
+        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
+        if self.strata_mapping is None:
+            self.strata_mapping = task_values
+        n_strata = self.strata_mapping.shape[0]
+        self._min_stdv = min_stdv
+        self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
+        self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_is_trained", torch.tensor(False))
+        self._batch_shape = batch_shape
+        self._m = 1  # TODO: support multiple outputs
+        self._outputs = None
+
+    def forward(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Standardize outcomes.
+
+        If the module is in train mode, this updates the module state (i.e. the
+        mean/std normalizing constants). If the module is in eval mode, simply
+        applies the normalization using the module state.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
+                associated with the training targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the transformed outcomes:
+
+            - The transformed outcome observations.
+            - The transformed observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        if self.training:
+            self._validate_training_inputs(Y=Y, Yvar=Yvar)
+            self.means = self.means.to(dtype=X.dtype, device=X.device)
+            self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
+            self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
+            strata = X[..., self._stratification_idx].long()
+            unique_strata = strata.unique()
+            for s in unique_strata:
+                mapped_strata = self.strata_mapping[s]
+                mask = strata != s
+                Y_strata = Y.clone()
+                Y_strata[..., mask, :] = float("nan")
+                stdvs = (
+                    torch.ones_like(Y_strata)
+                    if Y.shape[-2] == 1
+                    else nanstd(X=Y_strata, dim=-2)
+                )
+                stdvs = stdvs.where(
+                    stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
+                )
+                means = Y_strata.nanmean(dim=-2)
+                self.means[..., mapped_strata, :] = means
+                self.stdvs[..., mapped_strata, :] = stdvs
+                self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
+            self._is_trained = torch.tensor(True)
+        training = self.training
+        self.training = False
+        tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
+        self.training = training
+        return tf_Y, tf_Yvar
+
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+
+        Returns:
+            A three-tuple with the per-input means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        strata = X[..., self._stratification_idx].long()
+        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
+        # get means and stdvs for each stratum
+        n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
+        expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
+        means = torch.gather(
+            input=self.means.expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        stdvs = torch.gather(
+            input=self.stdvs.expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        if include_stdvs_sq:
+            stdvs_sq = torch.gather(
+                input=self._stdvs_sq.expand(expand_shape),
+                dim=-2,
+                index=mapped_strata,
+            )
+        else:
+            stdvs_sq = None
+        return means, stdvs, stdvs_sq
+
+    def subset_output(self, idcs: list[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        raise NotImplementedError
+
+    def untransform(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Un-standardize outcomes.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of standardized targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
+                noises associated with the targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the un-standardized outcomes:
+
+            - The un-standardized outcome observations.
+            - The un-standardized observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform(Y=Y, Yvar=Yvar, X=X)
+
+    def untransform_posterior(
+        self, posterior: Posterior, X: Tensor | None = None
+    ) -> GPyTorchPosterior | TransformedPosterior:
+        r"""Un-standardize the posterior.
+
+        Args:
+            posterior: A posterior in the standardized space.
+            X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
+
+        Returns:
+            The un-standardized posterior. If the input posterior is a
+            `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
+            `TransformedPosterior`.
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform_posterior(posterior=posterior, X=X)
+
+
 class Log(OutcomeTransform):
     r"""Log-transform outcomes.
 
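
As a small end-to-end sketch of the new transform (the data is made up; the API is exactly what this diff adds): in train mode, `forward` fits a separate mean/std per stratum via nan-masking, and `_get_per_input_means_stdvs` later looks those constants up per row with `torch.gather` on the mapped strata indices.

import torch

from botorch.models.transforms.outcome import StratifiedStandardize

# Hypothetical data: column 1 of X holds the stratum (task) id.
X = torch.tensor(
    [[0.1, 0.0], [0.4, 0.0], [0.2, 1.0], [0.9, 1.0]], dtype=torch.double
)
Y = torch.tensor([[1.0], [3.0], [10.0], [30.0]], dtype=torch.double)

tf = StratifiedStandardize(task_values=torch.tensor([0, 1]), stratification_idx=1)
Y_tf, _ = tf(Y, None, X)  # train mode: fits per-stratum means/stdvs, then transforms
# Stratum 0 is standardized around its own mean (2.0) and stratum 1 around its
# own mean (20.0), each with its own std, rather than one global mean/std.
tf.eval()
Y_utf, _ = tf.untransform(Y=Y_tf, X=X)  # recovers Y up to numerical error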
