Commit b2f7354

vizier-team authored and copybara-github committed

Internal Change

PiperOrigin-RevId: 552941861

1 parent: e8d9080

File tree: 6 files changed, +963 −88 lines changed

vizier/_src/algorithms/designers/gp/gp_models.py (249 additions, 25 deletions)
```diff
@@ -15,11 +15,15 @@
 from __future__ import annotations
 
 """Gaussian Process models."""
-
 import logging
+from typing import Iterable, Optional, Union
+
+import chex
 import equinox as eqx
 import jax
 from tensorflow_probability.substrates import jax as tfp
+from vizier._src.algorithms.designers.gp import acquisitions
+from vizier._src.algorithms.designers.gp import transfer_learning as vtl
 from vizier._src.jax import stochastic_process_model as sp
 from vizier._src.jax import types
 from vizier._src.jax.models import tuned_gp_models
```
from vizier._src.jax.models import tuned_gp_models
```diff
@@ -28,16 +32,109 @@
 tfd = tfp.distributions
 
 
+class GPTrainingSpec(eqx.Module):
+  """Encapsulates all the information needed to train a single GP model.
+
+  Attributes:
+    ard_optimizer: An `Optimizer` which should return a batch of hyperparameters
+      to be ensembled.
+    ard_rng: PRNG key for the ARD optimization.
+    coroutine: The model coroutine.
+    ensemble_size: If set, ensembles `ensemble_size` GP models together.
+    ard_random_restarts: The number of random restarts.
+  """
+
+  ard_optimizer: optimizers.Optimizer[types.ParameterDict]
+  ard_rng: jax.random.KeyArray
+  coroutine: sp.ModelCoroutine
+  ensemble_size: int = eqx.field(static=True, default=1)
+  ard_random_restarts: int = eqx.field(
+      static=True, default=optimizers.DEFAULT_RANDOM_RESTARTS
+  )
```
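As a rough illustration of how the new spec is meant to be used (a sketch, not code from this commit: `my_optimizer`, `features`, and `data` are hypothetical stand-ins for an `optimizers.Optimizer` instance and the usual `types.ModelInput`/`types.ModelData` values), the former `train_gp` keyword arguments now travel together in a `GPTrainingSpec` that is passed to the reworked `train_gp` shown further down:

```python
import jax

# Hypothetical stand-ins: `my_optimizer` is whatever `optimizers.Optimizer`
# the caller already uses; `features` and `data` come from the usual
# trial-to-array conversion.
spec = GPTrainingSpec(
    ard_optimizer=my_optimizer,
    ard_rng=jax.random.PRNGKey(0),
    coroutine=get_vizier_gp_coroutine(features),
    ensemble_size=4,  # keep the 4 best ARD restarts as a uniform ensemble
)
gp = train_gp(spec, data)  # new spec-based entry point (see below)
```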
```diff
 class GPState(eqx.Module):
-  """Stores a GP model `predictive` and the data used during training.
+  """A GP model and its training data. Implements `Predictive`.
 
   The data is kept around to deduce degrees of freedom and other related
-  metrics.
+  metrics. This implements `Predictive`, so that it and any of its derived
+  classes like `StackedResidualGP` can be used as a `Predictive` in
+  `acquisitions.py`.
   """
 
   predictive: sp.UniformEnsemblePredictive
   data: types.ModelData
 
+  def predict_with_aux(
+      self, features: types.ModelInput
+  ) -> tuple[tfd.Distribution, chex.ArrayTree]:
+    return self.predictive.predict_with_aux(features)
+
+  def num_hyperparameters(self) -> int:
+    """Returns the number of hyperparameters optimized on `data`."""
+
+    # For a GP model, this is feature dimensionality + 2
+    # (length scales, amplitude, observation noise).
+    # TODO: Compute this from the params returned by the ARD optimizer.
+    return (
+        self.data.features.continuous.shape[1]  # (num_samples, num_features)
+        + self.data.features.categorical.shape[1]  # (num_samples, num_features)
+        + 2
+    )
```
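The counting rule in `num_hyperparameters` is easy to sanity-check by hand: one ARD length scale per feature column, plus a shared amplitude and an observation-noise scale. A worked example with made-up dimensions:

```python
# Hypothetical space: 4 continuous and 2 categorical feature columns.
num_continuous = 4   # data.features.continuous.shape[1]
num_categorical = 2  # data.features.categorical.shape[1]
num_hparams = num_continuous + num_categorical + 2  # + amplitude + noise
assert num_hparams == 8
```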
```diff
+class StackedResidualGP(GPState):
+  """GP that implements the `predictive` interface and contains stacked GPs.
+
+  This GP handles sequential transfer learning: it holds one or no base
+  (prior) GP, along with a current top-level GP. The training process is such
+  that the 'top' GP is trained on the residuals of the predictions from the
+  base GP. The inference process is such that the predictions of the base
+  GP and the 'top' GP are combined together.
+
+  The base GP may also have its own base GPs and be a `StackedResidualGP`.
+  """
+
+  # `base_gp` refers to a GP trained and conditioned on previous data for
+  # transfer learning. The top-level GP is trained on the residuals from
+  # `base_gp` on `data`.
+  # If `None`, no transfer learning is used and all predictions happen through
+  # `predictive`.
+  base_gp: Optional[GPState] = None
+
+  def predict_with_aux(
+      self, features: types.ModelInput
+  ) -> tuple[tfd.Distribution, chex.ArrayTree]:
+    # Override the existing implementation of `predict_with_aux` to handle
+    # combining `predictive` with `base_gp`.
+    if not self.base_gp:
+      return super().predict_with_aux(features)
+
+    base_pred_dist, base_aux = self.base_gp.predict_with_aux(features)
+    top_pred_dist, top_aux = self.predictive.predict_with_aux(features)
+
+    base_pred = vtl.TransferPredictionState(
+        pred=base_pred_dist,
+        aux=base_aux,
+        training_data_count=self.base_gp.data.labels.shape[0],
+        num_hyperparameters=self.num_hyperparameters(),
+    )
+    top_pred = vtl.TransferPredictionState(
+        pred=top_pred_dist,
+        aux=top_aux,
+        training_data_count=self.data.labels.shape[0],
+        num_hyperparameters=self.num_hyperparameters(),
+    )
+
+    # TODO: Decide what to do with
+    # `expected_base_stddev_mismatch` - currently set to default.
+    comb_dist, aux = vtl.combine_predictions_with_aux(
+        top_pred=top_pred, base_pred=base_pred
+    )
+
+    return comb_dist, aux
 
 
 def get_vizier_gp_coroutine(
     features: types.ModelInput, *, linear_coef: float = 0.0
```
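Conceptually, since the top GP is fit to residuals, the stacked prediction adds its mean back onto the base GP's mean; the actual reconciliation of the two predictive distributions (including `expected_base_stddev_mismatch`) lives in `transfer_learning.combine_predictions_with_aux` and is not shown in this diff. A toy sketch of the mean combination only:

```python
import numpy as np

# Base GP prediction and residual ("top") GP prediction at two test points.
base_mean = np.array([1.0, 2.0])
top_mean = np.array([0.1, -0.3])  # the top GP was trained on residuals

# Stacked mean estimate: the base prediction corrected by the residual model.
combined_mean = base_mean + top_mean  # -> [1.1, 1.7]
```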
```diff
@@ -60,45 +157,32 @@ def get_vizier_gp_coroutine(
   return tuned_gp_models.VizierGaussianProcess.build_model(features).coroutine
 
 
-def train_gp(
-    data: types.ModelData,
-    ard_optimizer: optimizers.Optimizer[types.ParameterDict],
-    ard_rng: jax.random.KeyArray,
-    *,
-    coroutine: sp.ModelCoroutine,
-    ensemble_size: int = 1,
-    ard_random_restarts: int = optimizers.DEFAULT_RANDOM_RESTARTS,
-) -> GPState:
+def _train_gp(spec: GPTrainingSpec, data: types.ModelData) -> GPState:
   """Trains a Gaussian Process model.
 
   1. Performs ARD to find the best model parameters.
   2. Pre-computes the Cholesky decomposition for the model.
 
   Args:
-    data: Data to train the GP model(s) on.
-    ard_optimizer: An `Optimizer` which should return a batch of hyperparameters
-      to be ensembled.
-    ard_rng: PRNG key for the ARD optimization.
-    coroutine: The model coroutine.
-    ensemble_size: If set, ensembles `ensemble_size` GP models together.
-    ard_random_restarts: The number of random restarts.
+    spec: Spec required to train the GP. See `GPTrainingSpec` for more info.
+    data: Data on which to train the GP.
 
   Returns:
     The trained GP model.
   """
-  model = sp.CoroutineWithData(coroutine, data)
+  model = sp.CoroutineWithData(spec.coroutine, data)
 
   # Optimize the parameters
-  ard_rngs = jax.random.split(ard_rng, ard_random_restarts + 1)
-  best_params, _ = ard_optimizer(
+  ard_rngs = jax.random.split(spec.ard_rng, spec.ard_random_restarts + 1)
+  best_params, _ = spec.ard_optimizer(
       eqx.filter_jit(eqx.filter_vmap(model.setup))(ard_rngs[1:]),
       model.loss_with_aux,
       ard_rngs[0],
       constraints=model.constraints(),
-      best_n=ensemble_size or 1,
+      best_n=spec.ensemble_size or 1,
   )
   best_models = sp.StochasticProcessWithCoroutine(
-      coroutine=coroutine, params=best_params
+      coroutine=spec.coroutine, params=best_params
   )
   # Logging for debugging purposes.
   logging.info(
@@ -107,5 +191,145 @@ def train_gp(
   predictive = sp.UniformEnsemblePredictive(
       eqx.filter_jit(best_models.precompute_predictive)(data)
   )
-
   return GPState(predictive=predictive, data=data)
```
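The RNG bookkeeping in `_train_gp` splits `spec.ard_rng` into one key for the optimizer itself plus one key per random restart; `best_n=spec.ensemble_size or 1` then keeps the top parameter sets for the uniform ensemble. A small sketch of the fan-out (the restart count here is arbitrary, not the library default):

```python
import jax

ard_rng = jax.random.PRNGKey(0)
num_restarts = 32  # stands in for spec.ard_random_restarts

# One extra key: ard_rngs[0] seeds the optimizer, while ard_rngs[1:] seed
# the independently initialized restarts vmapped through model.setup.
ard_rngs = jax.random.split(ard_rng, num_restarts + 1)
opt_rng, restart_rngs = ard_rngs[0], ard_rngs[1:]
assert restart_rngs.shape[0] == num_restarts
```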
```diff
+@jax.jit
+def _pred_mean(
+    pred: acquisitions.Predictive, features: types.ModelInput
+) -> types.Array:
+  """Returns the mean of the predictions from `pred` on `features`.
+
+  Workaround while `eqx.filter_jit(pred.predict_with_aux)(features)` is broken
+  due to a bug in TensorFlow Probability.
+
+  Args:
+    pred: `Predictive` to predict with.
+    features: Xs to predict on.
+
+  Returns:
+    Means of the predictions from `pred` on `features`.
+  """
+  return pred.predict_with_aux(features)[0].mean()
+
+
+def _train_stacked_residual_gp(
+    base_gp: GPState,
+    spec: GPTrainingSpec,
+    data: types.ModelData,
+) -> StackedResidualGP:
+  """Trains a `StackedResidualGP`.
+
+  Completes the following steps in order:
+    1. Uses `base_gp` to predict on `data`.
+    2. Computes the residuals from the above predictions.
+    3. Trains a top-level GP on the above residuals.
+    4. Returns a `StackedResidualGP` combining the base GP and the
+       newly-trained GP.
+
+  Args:
+    base_gp: The predictive to use as the base GP for the `StackedResidualGP`
+      training.
+    spec: Training spec for the top-level GP.
+    data: Training data for the top-level GP.
+
+  Returns:
+    The trained `StackedResidualGP`.
+  """
+  # Compute the residuals of `data` as predicted by `base_gp`.
+  pred_means = _pred_mean(base_gp, data.features)
+
+  has_no_padding = ~(
+      data.features.continuous.is_missing[0]
+      | data.features.categorical.is_missing[0]
+      | data.labels.is_missing[0]
+  )
+
+  # Scope this to non-padded predictions only.
+  pred_means_no_padding = pred_means[has_no_padding]
+  residuals = (
+      data.labels.unpad().reshape(pred_means_no_padding.shape)
+      - pred_means_no_padding
+  )
+
+  # Train on the re-padded residuals.
+  residual_labels = types.PaddedArray.from_array(
+      array=residuals,
+      target_shape=data.labels.shape,
+      fill_value=data.labels.fill_value,
+  )
+  data_with_residuals = types.ModelData(
+      features=data.features, labels=residual_labels
+  )
+
+  top_gp = _train_gp(spec=spec, data=data_with_residuals)
+  return StackedResidualGP(
+      predictive=top_gp.predictive,
+      data=top_gp.data,
+      base_gp=base_gp,
+  )
```
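Ignoring the `PaddedArray` bookkeeping, steps 1 and 2 above reduce to subtracting the base GP's predicted means from the observed labels. A toy version with plain arrays:

```python
import numpy as np

labels = np.array([3.0, 5.0, 2.0])           # observed objective values
base_pred_means = np.array([2.5, 5.5, 2.0])  # base GP predictions (step 1)

# Step 2: the top GP's training targets are what the base GP got wrong.
residuals = labels - base_pred_means  # -> [0.5, -0.5, 0.0]
```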
```diff
+def train_gp(
+    spec: Union[GPTrainingSpec, Iterable[GPTrainingSpec]],
+    data: Union[types.ModelData, Iterable[types.ModelData]],
+) -> GPState:
+  """Trains a Gaussian Process model.
+
+  If `spec` contains multiple elements, each will be used to train a
+  `StackedResidualGP`, sequentially. The last entry will be used to train the
+  first GP, and then subsequent GPs will be trained on the residuals from the
+  previous GP. This process completes in reverse order, such that `spec[-1]`
+  is the first GP trained and `spec[0]` is the last GP trained.
+
+  `spec[0]` and `data[0]` make up the top-level GP, and `spec[1:]` and
+  `data[1:]` define the priors in the context of transfer learning.
+
+  Args:
+    spec: Specification for how to train a GP model. If multiple specs are
+      provided, transfer learning will train multiple models and combine them
+      into a single GP.
+    data: Data on which to train GPs. NOTE: `spec` and `data` must be of the
+      same shape. Trains a GP on `data[i]` with `spec[i]`.
+
+  Returns:
+    The trained GP model.
+  """
+  is_singleton_spec = isinstance(spec, GPTrainingSpec)
+  is_singleton_data = isinstance(data, types.ModelData)
+  if is_singleton_spec != is_singleton_data:
+    raise ValueError(
+        '`train_gp` expected the shapes of `spec` and `data` to be identical.'
+        f' Instead got `data` {data} but `spec` {spec}.'
+    )
+
+  if is_singleton_spec and is_singleton_data:
+    return _train_gp(spec=spec, data=data)
+
+  if len(spec) != len(data):
+    raise ValueError(
+        '`train_gp` expected the shapes of `spec` and `data` to be identical.'
+        f' Instead got `spec` of length {len(spec)} but `data` of length'
+        f' {len(data)}. `spec` was {spec} and `data` was {data}.'
+    )
+
+  curr_gp: Optional[GPState] = None
+  for curr_spec, curr_data in reversed(list(zip(spec, data))):
+    if curr_gp is None:
+      # We are on the first iteration.
+      curr_gp = _train_gp(spec=curr_spec, data=curr_data)
+    else:
+      # Otherwise, we have a base GP to use - the GP trained on the last
+      # iteration.
+      curr_gp = _train_stacked_residual_gp(
+          base_gp=curr_gp,
+          spec=curr_spec,
+          data=curr_data,
+      )
+
+  if curr_gp is None:
+    raise ValueError(
+        f'Failed to train a GP with the provided training spec: {spec} and'
+        f' data: {data}. `curr_gp` was never updated. This should never'
+        ' happen.'
+    )
+  return curr_gp
```
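Putting it together, a hypothetical two-study transfer-learning call (the specs, data, and `test_features` below are stand-ins, not values from this commit) trains the prior GP first and then stacks the current study's GP on its residuals:

```python
# `prior_spec`/`prior_data` come from an earlier study;
# `top_spec`/`top_data` from the current one. Entries are consumed in
# reverse order, so the prior GP is trained first and the top GP is fit
# to its residuals on `top_data`.
gp = train_gp(
    spec=[top_spec, prior_spec],
    data=[top_data, prior_data],
)

# The result behaves as a single `Predictive`: predictions combine the
# base and residual GPs internally.
dist, _ = gp.predict_with_aux(test_features)
means = dist.mean()
```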
