from __future__ import annotations

"""Gaussian Process models."""

import logging
from typing import Iterable, Optional, Union

import chex
import equinox as eqx
import jax
from tensorflow_probability.substrates import jax as tfp
from vizier._src.algorithms.designers.gp import acquisitions
from vizier._src.algorithms.designers.gp import transfer_learning as vtl
from vizier._src.jax import stochastic_process_model as sp
from vizier._src.jax import types
from vizier._src.jax.models import tuned_gp_models
from vizier.jax import optimizers

tfd = tfp.distributions


class GPTrainingSpec(eqx.Module):
  """Encapsulates all the information needed to train a single GP model.

  Attributes:
    ard_optimizer: An `Optimizer` which should return a batch of
      hyperparameters to be ensembled.
    ard_rng: PRNG key for the ARD optimization.
    coroutine: The model coroutine.
    ensemble_size: If set, ensembles `ensemble_size` GP models together.
    ard_random_restarts: The number of random restarts for ARD optimization.
  """

  ard_optimizer: optimizers.Optimizer[types.ParameterDict]
  ard_rng: jax.random.KeyArray
  coroutine: sp.ModelCoroutine
  ensemble_size: int = eqx.field(static=True, default=1)
  ard_random_restarts: int = eqx.field(
      static=True, default=optimizers.DEFAULT_RANDOM_RESTARTS
  )
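

# A minimal construction sketch (hypothetical values: `my_ard_optimizer`
# stands in for any `optimizers.Optimizer[types.ParameterDict]`, and
# `features` for a `types.ModelInput`):
#
#   spec = GPTrainingSpec(
#       ard_optimizer=my_ard_optimizer,
#       ard_rng=jax.random.PRNGKey(0),
#       coroutine=get_vizier_gp_coroutine(features),
#       ensemble_size=4,
#   )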


class GPState(eqx.Module):
  """A GP model and its training data. Implements `Predictive`.

  The data is kept around to deduce degrees of freedom and other related
  metrics. This class implements `Predictive` so that it, and any of its
  derived classes such as `StackedResidualGP`, can be used as a `Predictive`
  in `acquisitions.py`.
  """

  predictive: sp.UniformEnsemblePredictive
  data: types.ModelData

  def predict_with_aux(
      self, features: types.ModelInput
  ) -> tuple[tfd.Distribution, chex.ArrayTree]:
    """Delegates prediction to the underlying ensemble `predictive`."""
    return self.predictive.predict_with_aux(features)

  def num_hyperparameters(self) -> int:
    """Returns the number of hyperparameters optimized on `data`."""

    # For a GP model, this is the feature dimensionality + 2: one length scale
    # per feature, plus the amplitude and the observation noise.
    # TODO: Compute this from the params returned by the ARD optimizer.
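    # Worked example (illustrative numbers): with 4 continuous and 3
    # categorical feature columns, this returns 4 + 3 + 2 = 9.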
    return (
        self.data.features.continuous.shape[1]  # (num_samples, num_features)
        + self.data.features.categorical.shape[1]  # (num_samples, num_features)
        + 2
    )


class StackedResidualGP(GPState):
  """GP that implements the `Predictive` interface and contains stacked GPs.

  This GP handles sequential transfer learning: it holds at most one base
  (prior) GP, along with a current top-level GP. The training process is such
  that the 'top' GP is trained on the residuals of the predictions from the
  base GP. The inference process is such that the predictions of the base GP
  and the 'top' GP are combined together.

  The base GP may itself be a `StackedResidualGP` with its own base GP.
  """

  # `base_gp` refers to a GP trained and conditioned on previous data for
  # transfer learning. The top-level GP is trained on the residuals from
  # `base_gp` on `data`.
  # If `None`, no transfer learning is used and all predictions happen through
  # `predictive`.
  base_gp: Optional[GPState] = None

  def predict_with_aux(
      self, features: types.ModelInput
  ) -> tuple[tfd.Distribution, chex.ArrayTree]:
    # Override the base implementation of `predict_with_aux` to handle
    # combining `predictive` with `base_gp`.
    if not self.base_gp:
      return super().predict_with_aux(features)

    base_pred_dist, base_aux = self.base_gp.predict_with_aux(features)
    top_pred_dist, top_aux = self.predictive.predict_with_aux(features)

    base_pred = vtl.TransferPredictionState(
        pred=base_pred_dist,
        aux=base_aux,
        training_data_count=self.base_gp.data.labels.shape[0],
        num_hyperparameters=self.num_hyperparameters(),
    )
    top_pred = vtl.TransferPredictionState(
        pred=top_pred_dist,
        aux=top_aux,
        training_data_count=self.data.labels.shape[0],
        num_hyperparameters=self.num_hyperparameters(),
    )

    # TODO: Decide what to do with `expected_base_stddev_mismatch` -
    # currently set to default.
    comb_dist, aux = vtl.combine_predictions_with_aux(
        top_pred=top_pred, base_pred=base_pred
    )

    return comb_dist, aux
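
# Conceptual sketch, not the exact combination rule (which lives in
# `vtl.combine_predictions_with_aux`): because the top GP is fit to
# `labels - base_mean` on the training data, the stacked posterior mean at a
# query point x behaves roughly like
#
#   mean(x) ≈ base_mean(x) + top_mean(x)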


def get_vizier_gp_coroutine(
    features: types.ModelInput, *, linear_coef: float = 0.0
) -> sp.ModelCoroutine:
  """Returns the model coroutine for a Vizier GP.

  Args:
    features: Features on which the GP will be trained.
    linear_coef: If non-zero, the GP uses an additional linear kernel with
      `linear_coef` as its hyperparameter.

  Returns:
    The model coroutine.
  """
  if linear_coef:
    return tuned_gp_models.VizierLinearGaussianProcess.build_model(
        features=features, linear_coef=linear_coef
    ).coroutine

  return tuned_gp_models.VizierGaussianProcess.build_model(features).coroutine


def _train_gp(spec: GPTrainingSpec, data: types.ModelData) -> GPState:
  """Trains a Gaussian Process model.

  1. Performs ARD to find the best model parameters.
  2. Pre-computes the Cholesky decomposition for the model.

  Args:
    spec: Spec required to train the GP. See `GPTrainingSpec` for more info.
    data: Data on which to train the GP.

  Returns:
    The trained GP model.
  """
  model = sp.CoroutineWithData(spec.coroutine, data)

  # Optimize the parameters.
  ard_rngs = jax.random.split(spec.ard_rng, spec.ard_random_restarts + 1)
  best_params, _ = spec.ard_optimizer(
      eqx.filter_jit(eqx.filter_vmap(model.setup))(ard_rngs[1:]),
      model.loss_with_aux,
      ard_rngs[0],
      constraints=model.constraints(),
      best_n=spec.ensemble_size or 1,
  )
  best_models = sp.StochasticProcessWithCoroutine(
      coroutine=spec.coroutine, params=best_params
  )
  # Logging for debugging purposes.
  logging.info('Best models: %s', best_models)
  predictive = sp.UniformEnsemblePredictive(
      eqx.filter_jit(best_models.precompute_predictive)(data)
  )
  return GPState(predictive=predictive, data=data)


@jax.jit
def _pred_mean(
    pred: acquisitions.Predictive, features: types.ModelInput
) -> types.Array:
  """Returns the mean of the predictions from `pred` on `features`.

  Workaround while `eqx.filter_jit(pred.predict_with_aux)(features)` is broken
  due to a bug in TensorFlow Probability.

  Args:
    pred: `Predictive` to predict with.
    features: Xs to predict on.

  Returns:
    Means of the predictions from `pred` on `features`.
  """
  return pred.predict_with_aux(features)[0].mean()


def _train_stacked_residual_gp(
    base_gp: GPState,
    spec: GPTrainingSpec,
    data: types.ModelData,
) -> StackedResidualGP:
  """Trains a `StackedResidualGP`.

  Completes the following steps in order:
    1. Uses `base_gp` to predict on `data`.
    2. Computes the residuals from the above predictions.
    3. Trains a top-level GP on the above residuals.
    4. Returns a `StackedResidualGP` combining the base GP and the
       newly-trained GP.

  Args:
    base_gp: The predictive to use as the base GP for the `StackedResidualGP`
      training.
    spec: Training spec for the top-level GP.
    data: Training data for the top-level GP.

  Returns:
    The trained `StackedResidualGP`.
  """
  # Compute the residuals of `data` as predicted by `base_gp`.
  pred_means = _pred_mean(base_gp, data.features)

  # `is_missing[0]` marks padded (fake) rows; a row is real only if it is
  # unpadded in both the features and the labels.
  has_no_padding = ~(
      data.features.continuous.is_missing[0]
      | data.features.categorical.is_missing[0]
      | data.labels.is_missing[0]
  )

  # Scope this to non-padded predictions only.
  pred_means_no_padding = pred_means[has_no_padding]
  residuals = (
      data.labels.unpad().reshape(pred_means_no_padding.shape)
      - pred_means_no_padding
  )
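  # E.g. if the unpadded labels are [1.0, 2.0] and the base GP predicts means
  # [0.8, 2.5], the top GP will be trained on residuals [0.2, -0.5].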
254+
255+ # Train on the re-padded residuals
256+ residual_labels = types .PaddedArray .from_array (
257+ array = residuals ,
258+ target_shape = data .labels .shape ,
259+ fill_value = data .labels .fill_value ,
260+ )
261+ data_with_residuals = types .ModelData (
262+ features = data .features , labels = residual_labels
263+ )
264+
265+ top_gp = _train_gp (spec = spec , data = data_with_residuals )
266+ return StackedResidualGP (
267+ predictive = top_gp .predictive ,
268+ data = top_gp .data ,
269+ base_gp = base_gp ,
270+ )


def train_gp(
    spec: Union[GPTrainingSpec, Iterable[GPTrainingSpec]],
    data: Union[types.ModelData, Iterable[types.ModelData]],
) -> GPState:
  """Trains a Gaussian Process model.

  If `spec` contains multiple elements, each is used to train a
  `StackedResidualGP`, sequentially. The last entry trains the first GP, and
  each subsequent GP is trained on the residuals of the previous GP. Training
  proceeds in reverse order, such that `spec[-1]` is the first GP trained and
  `spec[0]` is the last.

  `spec[0]` and `data[0]` make up the top-level GP, while `spec[1:]` and
  `data[1:]` define the priors in the context of transfer learning.

  Args:
    spec: Specification for how to train a GP model. If multiple specs are
      provided, transfer learning will train multiple models and combine them
      into a single GP.
    data: Data on which to train GPs. NOTE: `spec` and `data` must have the
      same shape; a GP is trained on `data[i]` with `spec[i]`.

  Returns:
    The trained GP model.
  """
  is_singleton_spec = isinstance(spec, GPTrainingSpec)
  is_singleton_data = isinstance(data, types.ModelData)
  if is_singleton_spec != is_singleton_data:
    raise ValueError(
        '`train_gp` expected the shapes of `spec` and `data` to be identical.'
        f' Instead got `data` {data} but `spec` {spec}.'
    )

  if is_singleton_spec and is_singleton_data:
    return _train_gp(spec=spec, data=data)

  if len(spec) != len(data):
    raise ValueError(
        '`train_gp` expected the shapes of `spec` and `data` to be identical.'
        f' Instead got `spec` of length {len(spec)} but `data` of length'
        f' {len(data)}. `spec` was {spec} and `data` was {data}.'
    )

  curr_gp: Optional[GPState] = None
  for curr_spec, curr_data in reversed(list(zip(spec, data))):
    if curr_gp is None:
      # We are on the first iteration: train the base GP.
      curr_gp = _train_gp(spec=curr_spec, data=curr_data)
    else:
      # Otherwise, use the GP trained on the previous iteration as the base
      # GP and train on its residuals.
      curr_gp = _train_stacked_residual_gp(
          base_gp=curr_gp,
          spec=curr_spec,
          data=curr_data,
      )

  if curr_gp is None:
    raise ValueError(
        f'Failed to train a GP with provided training spec: {spec} and'
        f' data: {data}. `curr_gp` was never updated. This should never'
        ' happen.'
    )
  return curr_gp
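

# Usage sketch (hypothetical `spec_top`, `spec_base` and matching `ModelData`
# objects, built as in the `GPTrainingSpec` example above):
#
#   gp = train_gp(spec=spec_top, data=data_top)  # A single GP.
#
#   # Transfer learning: `spec_base`/`data_base` train the base (prior) GP
#   # first, and the top GP is then fit to the residuals of its predictions.
#   gp = train_gp(spec=[spec_top, spec_base], data=[data_top, data_base])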