27
27
28
28
import torch
29
29
from botorch .models .transforms .utils import (
30
+ nanstd ,
30
31
norm_to_lognorm_mean ,
31
32
norm_to_lognorm_variance ,
32
33
)
34
+ from botorch .models .utils .assorted import get_task_value_remapping
33
35
from botorch .posteriors import GPyTorchPosterior , Posterior , TransformedPosterior
34
36
from botorch .utils .transforms import normalize_indices
35
37
from linear_operator .operators import CholLinearOperator , DiagLinearOperator
@@ -259,6 +261,46 @@ def __init__(
259
261
self ._batch_shape = batch_shape
260
262
self ._min_stdv = min_stdv
261
263
264
+ def _get_per_input_means_stdvs (
265
+ self , X : Tensor , include_stdvs_sq : bool
266
+ ) -> tuple [Tensor , Tensor , Tensor | None ]:
267
+ r"""Get per-input means and stdvs.
268
+
269
+ Args:
270
+ X: A `batch_shape x n x d`-dim tensor of input parameters.
271
+ include_stdvs_sq: Whether to include the stdvs squared.
272
+ This parameter is not used by this method
273
+
274
+ Returns:
275
+ A three-tuple with the means and stdvs:
276
+
277
+ - The per-input means.
278
+ - The per-input stdvs.
279
+ - The per-input stdvs squared.
280
+ """
281
+ return self .means , self .stdvs , self ._stdvs_sq
282
+
283
+ def _validate_training_inputs (self , Y : Tensor , Yvar : Tensor | None = None ) -> None :
284
+ """Validate training inputs.
285
+
286
+ Args:
287
+ Y: A `batch_shape x n x m`-dim tensor of training targets.
288
+ Yvar: A `batch_shape x n x m`-dim tensor of observation noises.
289
+ """
290
+ if Y .shape [:- 2 ] != self ._batch_shape :
291
+ raise RuntimeError (
292
+ f"Expected Y.shape[:-2] to be { self ._batch_shape } , matching "
293
+ f"the `batch_shape` argument to `{ self .__class__ .__name__ } `, but got "
294
+ f"Y.shape[:-2]={ Y .shape [:- 2 ]} ."
295
+ )
296
+ elif Y .shape [- 2 ] < 1 :
297
+ raise ValueError (f"Can't standardize with no observations. { Y .shape = } ." )
298
+ elif Y .size (- 1 ) != self ._m :
299
+ raise RuntimeError (
300
+ f"Wrong output dimension. Y.size(-1) is { Y .size (- 1 )} ; expected "
301
+ f"{ self ._m } ."
302
+ )
303
+
262
304
def forward (
263
305
self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
264
306
) -> tuple [Tensor , Tensor | None ]:
@@ -283,21 +325,8 @@ def forward(
283
325
- The transformed observation noise (if applicable).
284
326
"""
285
327
if self .training :
286
- if Y .shape [:- 2 ] != self ._batch_shape :
287
- raise RuntimeError (
288
- f"Expected Y.shape[:-2] to be { self ._batch_shape } , matching "
289
- "the `batch_shape` argument to `Standardize`, but got "
290
- f"Y.shape[:-2]={ Y .shape [:- 2 ]} ."
291
- )
292
- if Y .size (- 1 ) != self ._m :
293
- raise RuntimeError (
294
- f"Wrong output dimension. Y.size(-1) is { Y .size (- 1 )} ; expected "
295
- f"{ self ._m } ."
296
- )
297
- if Y .shape [- 2 ] < 1 :
298
- raise ValueError (f"Can't standardize with no observations. { Y .shape = } ." )
299
-
300
- elif Y .shape [- 2 ] == 1 :
328
+ self ._validate_training_inputs (Y = Y , Yvar = Yvar )
329
+ if Y .shape [- 2 ] == 1 :
301
330
stdvs = torch .ones (
302
331
(* Y .shape [:- 2 ], 1 , Y .shape [- 1 ]), dtype = Y .dtype , device = Y .device
303
332
)
@@ -313,9 +342,12 @@ def forward(
313
342
self .stdvs = stdvs
314
343
self ._stdvs_sq = stdvs .pow (2 )
315
344
self ._is_trained = torch .tensor (True )
316
-
317
- Y_tf = (Y - self .means ) / self .stdvs
318
- Yvar_tf = Yvar / self ._stdvs_sq if Yvar is not None else None
345
+ include_stdvs_sq = Yvar is not None
346
+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
347
+ X = X , include_stdvs_sq = include_stdvs_sq
348
+ )
349
+ Y_tf = (Y - means ) / stdvs
350
+ Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
319
351
return Y_tf , Yvar_tf
320
352
321
353
def subset_output (self , idcs : list [int ]) -> OutcomeTransform :
@@ -376,9 +408,12 @@ def untransform(
376
408
"(e.g. `transform(Y)`) before calling `untransform`, since "
377
409
"means and standard deviations need to be computed."
378
410
)
379
-
380
- Y_utf = self .means + self .stdvs * Y
381
- Yvar_utf = self ._stdvs_sq * Yvar if Yvar is not None else None
411
+ include_stdvs_sq = Yvar is not None
412
+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
413
+ X = X , include_stdvs_sq = include_stdvs_sq
414
+ )
415
+ Y_utf = means + stdvs * Y
416
+ Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
382
417
return Y_utf , Yvar_utf
383
418
384
419
@property
@@ -433,8 +468,9 @@ def untransform_posterior(
433
468
)
434
469
# GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
435
470
mvn = posterior .distribution
436
- offset = self .means
437
- scale_fac = self .stdvs
471
+ offset , scale_fac , _ = self ._get_per_input_means_stdvs (
472
+ X = X , include_stdvs_sq = False
473
+ )
438
474
if not posterior ._is_mt :
439
475
mean_tf = offset .squeeze (- 1 ) + scale_fac .squeeze (- 1 ) * mvn .mean
440
476
scale_fac = scale_fac .squeeze (- 1 ).expand_as (mean_tf )
@@ -449,7 +485,6 @@ def untransform_posterior(
449
485
450
486
if (
451
487
not mvn .islazy
452
- # TODO: Figure out attribute namming weirdness here
453
488
or mvn ._MultivariateNormal__unbroadcasted_scale_tril is not None
454
489
):
455
490
# if already computed, we can save a lot of time using scale_tril
@@ -465,6 +500,197 @@ def untransform_posterior(
465
500
return GPyTorchPosterior (mvn_tf )
466
501
467
502
503
class StratifiedStandardize(Standardize):
    r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.

    Separate normalizing constants are kept for each stratum, where the
    stratum of an observation is read from one column of the inputs `X`.

    This module is stateful: If in train mode, calling forward updates the
    module state (i.e. the mean/std normalizing constants). If in eval mode,
    calling forward simply applies the standardization using the current module
    state.
    """

    def __init__(
        self,
        task_values: Tensor,
        stratification_idx: int,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Standardize outcomes (zero mean, unit variance) along stratification dim.

        Note: This currently only supports single output models
        (including multi-task models that have a single output).

        Args:
            task_values: `t`-dim tensor of task values.
            stratification_idx: The index of the stratification dimension.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to perform
                standardization (if lower, only de-mean the data).
        """
        OutcomeTransform.__init__(self)
        self._stratification_idx = stratification_idx
        task_values = task_values.unique(sorted=True)
        # `strata_mapping` is indexed by the raw stratum value and yields the
        # row of the statistics buffers to use. When the helper returns None,
        # the sorted unique task values themselves are used as the mapping.
        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
        if self.strata_mapping is None:
            self.strata_mapping = task_values
        n_strata = self.strata_mapping.shape[0]
        self._min_stdv = min_stdv
        # One row of statistics per mapping entry; single output column.
        self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
        self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
        self.register_buffer("_is_trained", torch.tensor(False))
        self._batch_shape = batch_shape
        self._m = 1  # TODO: support multiple outputs
        self._outputs = None

    def forward(
        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        r"""Standardize outcomes.

        If the module is in train mode, this updates the module state (i.e. the
        mean/std normalizing constants). If the module is in eval mode, simply
        applies the normalization using the module state.

        Args:
            Y: A `batch_shape x n x m`-dim tensor of training targets.
            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
                associated with the training targets (if applicable).
            X: A `batch_shape x n x d`-dim tensor of input parameters.
                Required; the stratum of each observation is read from
                column `stratification_idx`.

        Returns:
            A two-tuple with the transformed outcomes:

            - The transformed outcome observations.
            - The transformed observation noise (if applicable).

        Raises:
            ValueError: If `X` is None.
        """
        if X is None:
            raise ValueError("X is required for StratifiedStandardize.")
        if self.training:
            self._validate_training_inputs(Y=Y, Yvar=Yvar)
            self.means = self.means.to(dtype=X.dtype, device=X.device)
            self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
            self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
            strata = X[..., self._stratification_idx].long()
            unique_strata = strata.unique()
            # Loop-invariant: with a single observation no spread can be
            # estimated, so unit stdvs are used.
            single_observation = Y.shape[-2] == 1
            for s in unique_strata:
                mapped_strata = self.strata_mapping[s]
                # Blank out every observation that is NOT in stratum `s` so
                # the NaN-ignoring reductions below only see this stratum.
                mask = strata != s
                Y_strata = Y.clone()
                Y_strata[..., mask, :] = float("nan")
                stdvs = (
                    torch.ones_like(Y_strata)
                    if single_observation
                    else nanstd(X=Y_strata, dim=-2)
                )
                # Entries below `min_stdv` fall back to 1.0 (de-mean only).
                # NaN entries also fall back, since `>=` is False for NaN.
                stdvs = stdvs.where(
                    stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
                )
                means = Y_strata.nanmean(dim=-2)
                self.means[..., mapped_strata, :] = means
                self.stdvs[..., mapped_strata, :] = stdvs
                self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
            self._is_trained = torch.tensor(True)
        # The statistics are already up to date, so run the parent's forward
        # with the training flag cleared to skip its own statistics update.
        # The flag is restored even if the parent raises.
        training = self.training
        self.training = False
        try:
            tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
        finally:
            self.training = training
        return tf_Y, tf_Yvar

    def _get_per_input_means_stdvs(
        self, X: Tensor, include_stdvs_sq: bool
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""Get per-input means and stdvs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of input parameters.
            include_stdvs_sq: Whether to include the stdvs squared.

        Returns:
            A three-tuple with the per-input means and stdvs:

            - The per-input means.
            - The per-input stdvs.
            - The per-input stdvs squared, or None if `include_stdvs_sq`
              is False.
        """
        strata = X[..., self._stratification_idx].long()
        # Trailing singleton dim so the index lines up with the buffers'
        # single output column for `gather`.
        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
        # `X` may carry more leading batch dims than the stored statistics;
        # expand the buffers over those extra dims before gathering.
        n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
        expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
        # Pick, for each input row, the statistics row of its stratum.
        means = torch.gather(
            input=self.means.expand(expand_shape),
            dim=-2,
            index=mapped_strata,
        )
        stdvs = torch.gather(
            input=self.stdvs.expand(expand_shape),
            dim=-2,
            index=mapped_strata,
        )
        if include_stdvs_sq:
            stdvs_sq = torch.gather(
                input=self._stdvs_sq.expand(expand_shape),
                dim=-2,
                index=mapped_strata,
            )
        else:
            stdvs_sq = None
        return means, stdvs, stdvs_sq

    def subset_output(self, idcs: list[int]) -> OutcomeTransform:
        r"""Subset the transform along the output dimension.

        Not supported for this transform.

        Args:
            idcs: The output indices to subset the transform to.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def untransform(
        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        r"""Un-standardize outcomes.

        Args:
            Y: A `batch_shape x n x m`-dim tensor of standardized targets.
            Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
                noises associated with the targets (if applicable).
            X: A `batch_shape x n x d`-dim tensor of input parameters.
                Required.

        Returns:
            A two-tuple with the un-standardized outcomes:

            - The un-standardized outcome observations.
            - The un-standardized observation noise (if applicable).

        Raises:
            ValueError: If `X` is None.
        """
        if X is None:
            raise ValueError("X is required for StratifiedStandardize.")
        return super().untransform(Y=Y, Yvar=Yvar, X=X)

    def untransform_posterior(
        self, posterior: Posterior, X: Tensor | None = None
    ) -> GPyTorchPosterior | TransformedPosterior:
        r"""Un-standardize the posterior.

        Args:
            posterior: A posterior in the standardized space.
            X: A `batch_shape x n x d`-dim tensor of inputs. Required.

        Returns:
            The un-standardized posterior. If the input posterior is a
            `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
            `TransformedPosterior`.

        Raises:
            ValueError: If `X` is None.
        """
        if X is None:
            raise ValueError("X is required for StratifiedStandardize.")
        return super().untransform_posterior(posterior=posterior, X=X)
692
+
693
+
468
694
class Log (OutcomeTransform ):
469
695
r"""Log-transform outcomes.
470
696
0 commit comments