
Commit 393bd27

vizier-team authored and copybara-github committed
Internal Change
PiperOrigin-RevId: 552941861
1 parent edc4a34 commit 393bd27

File tree

4 files changed: +293, -59 lines changed

vizier/_src/algorithms/designers/gp/gp_models.py

Lines changed: 7 additions & 6 deletions

@@ -277,13 +277,14 @@ def train_gp(
   """Trains a Gaussian Process model.
 
   If `spec` contains multiple elements, each will be used to train a
-  `StackedResidualGP`, sequentially. The last entry will be used to train the
+  `StackedResidualGP`, sequentially. The first entry will be used to train the
   first GP, and then subsequent GPs will be trained on the residuals from the
-  previous GP. This process completes in reverse order, such that `spec[-1]` is
-  the first GP trained and `spec[0]` is the last GP trained.
+  previous GP. This process completes in the order that `spec` and `data` are
+  provided, such that `spec[0]` is the first GP trained and `spec[-1]` is the
+  last GP trained.
 
-  spec[0] and data[0] make up the top-level GP, and spec[1:] and data[1:] define
-  the priors in context of transfer learning.
+  spec[-1] and data[-1] make up the top-level GP, and spec[:-1] and data[:-1]
+  define the priors in context of transfer learning.
 
   Args:
     spec: Specification for how to train a GP model. If multiple specs are

@@ -314,7 +315,7 @@ def train_gp(
   )
 
   curr_gp: Optional[GPState] = None
-  for curr_spec, curr_data in reversed(list(zip(spec, data))):
+  for curr_spec, curr_data in zip(spec, data):
     if curr_gp is None:
       # We are on the first iteration.
       curr_gp = _train_gp(spec=curr_spec, data=curr_data)
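
To make the new ordering concrete, here is a short, self-contained sketch of the stacked-residual loop that `train_gp` now follows. `train_stacked` and its `fit` callback are hypothetical stand-ins for `train_gp` and `_train_gp`, not code from this commit:

def train_stacked(specs, datasets, fit):
  """Trains GPs left to right: index 0 first, the top-level model last."""
  gp = None
  for spec, data in zip(specs, datasets):
    if gp is None:
      gp = fit(spec, data)            # first prior GP
    else:
      gp = fit(spec, data, prior=gp)  # fit the residuals of the stack so far
  return gp  # specs[-1]/datasets[-1] produce the top-level GP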

vizier/_src/algorithms/designers/gp/gp_models_test.py

Lines changed: 7 additions & 5 deletions

@@ -191,7 +191,7 @@ def test_sequential_base_accuracy(
 
     # Combine the good base and the bad top into transfer learning GP.
     seq_base_gp = gp_models.train_gp(
-        [top_spec, base_spec], [top_train_data, base_train_data]
+        [base_spec, top_spec], [base_train_data, top_train_data]
     )
 
     # Create a purposefully-bad GP with `bad_num_samples` for comparison.

@@ -244,8 +244,8 @@ def test_multi_base(
         ensemble_size=ensemble_size,
     )
 
-    train_specs = [top_spec]
-    train_data = [top_train_data]
+    train_specs = []
+    train_data = []
 
     for _ in range(2):
       base_spec, base_train_data, _ = _setup_lambda_search(

@@ -257,6 +257,8 @@ def test_multi_base(
      )
      train_specs.append(base_spec)
      train_data.append(base_train_data)
+    train_specs.append(top_spec)
+    train_data.append(top_train_data)
 
     seq_base_gp = gp_models.train_gp(train_specs, train_data)

@@ -323,10 +325,10 @@ def test_bad_base_resilience(
     # Combine the good base and the bad top into transfer learning GP.
     seq_base_gp = gp_models.train_gp(
         [
-            top_spec,
             bad_base_spec,
+            top_spec,
         ],
-        [top_train_data, bad_base_train_data],
+        [bad_base_train_data, top_train_data],
     )
 
     # Create a GP on the fake objective with sufficient training data

vizier/_src/algorithms/designers/gp_bandit.py

Lines changed: 141 additions & 23 deletions

@@ -129,6 +129,17 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
 
   _last_computed_gp: gp_models.GPState = attr.field(init=False)
 
+  # The studies used in transfer learning. Ordered in training order, i.e.
+  # a GP is trained on `_prior_studies[0]` first, then one is trained on the
+  # residuals of `_prior_studies[1]` from the GP trained on `_prior_studies[0]`,
+  # and so on.
+  _prior_studies: list[vza.CompletedTrials] = attr.field(
+      factory=list, init=False
+  )
+  _incorporated_prior_study_count: int = attr.field(
+      default=0, kw_only=True, init=False
+  )
+
   default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
       strategy_factory=es.VectorizedEagleStrategyFactory()
   )

@@ -205,6 +216,32 @@ def update(
     del all_active
     self._trials.extend(copy.deepcopy(completed.trials))
 
+  def update_priors(self, prior_studies: Sequence[vza.CompletedTrials]) -> None:
+    """Updates the list of prior studies for transfer learning.
+
+    Each element is treated as a new prior study, and will be stacked in the
+    order received, i.e. the first entry is for the first GP, the second entry
+    is for the GP trained on the residuals of the first GP, etc.
+
+    See section 3.3 of https://dl.acm.org/doi/10.1145/3097983.3098043 for more
+    information, or see `gp/gp_models.py` and `gp/transfer_learning.py`.
+
+    Transfer learning is resilient to bad priors.
+
+    Multiple calls are permitted. It is up to the caller to ensure
+    `prior_studies` have a matching `ProblemStatement`; otherwise behavior is
+    undefined.
+
+    TODO: Decide on whether this method should become part of an
+    interface.
+
+    Args:
+      prior_studies: A list of lists of completed trials, with one list per
+        prior study. The designer will train a prior GP for each list of prior
+        trials (for each `CompletedTrials` entry), in the order received.
+    """
+    self._prior_studies.extend(copy.deepcopy(prior_studies))
+
   @property
   def _metric_info(self) -> vz.MetricInformation:
     return self._problem.metric_information.item()
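
For context, a hypothetical usage sketch of the new method follows. The designer construction, the `study_a_trials`/`study_b_trials`/`current_trials` variables, and the empty `vza.ActiveTrials()` default are assumptions for illustration; `update_priors` and `vza.CompletedTrials` are from this commit and the vizier abstractions:

designer = VizierGPBandit(problem)  # assumed constructor taking a ProblemStatement

# Priors stack in the order received: a GP is trained on `study_a_trials`
# first, then one on the residuals of `study_b_trials`.
designer.update_priors([
    vza.CompletedTrials(study_a_trials),  # hypothetical trial lists
    vza.CompletedTrials(study_b_trials),
])

# The current study's trials arrive as usual; they become the top-level GP
# the next time the designer retrains.
designer.update(vza.CompletedTrials(current_trials), vza.ActiveTrials())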
@@ -286,46 +323,103 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
     return types.ModelData(model_data.features, labels)
 
   @_experimental_override_allowed
-  def _train_gp(
+  def _create_gp_spec(
       self, data: types.ModelData, ard_rng: jax.random.KeyArray
-  ) -> gp_models.GPState:
-    """Overrideable training of a pre-computed ensemble GP."""
-    trained_gp = gp_models.train_gp(
-        spec=gp_models.GPTrainingSpec(
-            ard_optimizer=self._ard_optimizer,
-            ard_rng=ard_rng,
-            coroutine=gp_models.get_vizier_gp_coroutine(
-                features=data.features, linear_coef=self._linear_coef
-            ),
-            ensemble_size=self._ensemble_size,
-            ard_random_restarts=self._ard_random_restarts,
+  ) -> gp_models.GPTrainingSpec:
+    """Overrideable creation of a training spec for a GP model."""
+    return gp_models.GPTrainingSpec(
+        ard_optimizer=self._ard_optimizer,
+        ard_rng=ard_rng,
+        coroutine=gp_models.get_vizier_gp_coroutine(
+            features=data.features, linear_coef=self._linear_coef
         ),
-        data=data,
+        ensemble_size=self._ensemble_size,
+        ard_random_restarts=self._ard_random_restarts,
     )
-    return trained_gp
+
+  @_experimental_override_allowed
+  def _train_gp_with_priors(
+      self,
+      data: types.ModelData,
+      ard_rng: jax.random.KeyArray,
+      priors: Sequence[types.ModelData],
+  ):
+    """Trains a transfer-learning-enabled GP with prior studies.
+
+    Args:
+      data: Top-level data on which to train a GP.
+      ard_rng: RNG to do ARD to optimize GP parameters.
+      priors: Data for each sequential prior to train for transfer learning.
+        Assumed to be in training order, i.e. `priors[0]` is the first GP
+        trained, and `priors[1]` trains a GP on the residuals of the GP
+        trained on `priors[0]`, and so on.
+
+    Returns:
+      A trained pre-computed ensemble GP.
+    """
+    ard_rngs = jax.random.split(ard_rng, len(priors) + 1)
+
+    # Order `specs` in training order, i.e. `specs[0]` is trained first.
+    specs = [
+        self._create_gp_spec(prior_data, ard_rngs[i])
+        for i, prior_data in enumerate(priors)
+    ]
+
+    # Use the last rng for the top-level spec.
+    specs.append(self._create_gp_spec(data, ard_rngs[-1]))
+
+    # Order `training_data` in training order, i.e. `training_data[0]` is
+    # trained first.
+    training_data = list(priors)
+    training_data.append(data)
+
+    # `train_gp` expects `specs` and `data` in training order, which is how
+    # they were prepared above.
+    return gp_models.train_gp(spec=specs, data=training_data)
 
   @profiler.record_runtime
-  def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
+  def _update_gp(
+      self,
+      data: types.ModelData,
+      *,
+      prior_data: Optional[Sequence[types.ModelData]] = None,
+  ) -> gp_models.GPState:
     """Computes the designer's GP and caches the result. No-op without new data.
 
     Args:
       data: Data to go into GP.
+      prior_data: Data to train priors on, in training order.
 
     Returns:
-      GPBanditState object containing the designer's state.
+      `GPState` object containing the trained GP.
 
     1. Converts trials to features and labels.
     2. Trains a pre-computed ensemble GP.
 
     If no new trials were added since last call, no update will occur.
     """
-    if len(self._trials) == self._incorporated_trials_count:
-      # If there's no change in the number of completed trials, don't update
-      # state. The assumption is that trials can't be removed.
+    if (
+        len(self._trials) == self._incorporated_trials_count
+        and len(self._prior_studies) == self._incorporated_prior_study_count
+    ):
+      # If there's no change in the number of completed trials or the number of
+      # priors, don't update state. The assumption is that trials can't be
+      # removed.
      return self._last_computed_gp
    self._incorporated_trials_count = len(self._trials)
+    self._incorporated_prior_study_count = len(self._prior_studies)
    self._rng, ard_rng = jax.random.split(self._rng, 2)
-    self._last_computed_gp = self._train_gp(data=data, ard_rng=ard_rng)
+
+    if not prior_data:
+      self._last_computed_gp = gp_models.train_gp(
+          spec=self._create_gp_spec(data, ard_rng),
+          data=data,
+      )
+    else:
+      self._last_computed_gp = self._train_gp_with_priors(
+          data=data, ard_rng=ard_rng, priors=prior_data
+      )
+
     return self._last_computed_gp
 
   @_experimental_override_allowed
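
The RNG handling in `_train_gp_with_priors` is plain JAX key splitting: one independent ARD key per prior, plus one for the top-level spec. A standalone sketch (the seed and prior count below are placeholders):

import jax

root_rng = jax.random.PRNGKey(0)  # placeholder seed
num_priors = 2                    # placeholder for len(priors)

# Split into num_priors + 1 independent keys; the last key goes to the
# top-level spec, mirroring `specs.append(self._create_gp_spec(data, ard_rngs[-1]))`.
ard_rngs = jax.random.split(root_rng, num_priors + 1)
prior_keys, top_key = ard_rngs[:-1], ard_rngs[-1]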
@@ -379,6 +473,29 @@ def _optimize_acquisition(
         best_candidates, self._converter
     )  # [N, D]
 
+  @profiler.record_runtime
+  @_experimental_override_allowed
+  def _generate_data(
+      self,
+  ) -> tuple[types.ModelData, Optional[list[types.ModelData]]]:
+    """Converts trials to top-level and prior training data."""
+    prior_data: Optional[list[types.ModelData]] = None
+    if self._prior_studies:
+      prior_data = [
+          self._trials_to_data(prior_study.trials)
+          for prior_study in self._prior_studies
+      ]
+
+    # The top-level data must be converted last, because `_output_warper`
+    # depends on the support points that were supplied to it in `warp` to
+    # `unwarp` labels. It stores these support points each time `warp` is
+    # called, so the last `warp` call dictates the support points used in
+    # `unwarp`. Therefore, since we want to `unwarp` the predictions based on
+    # the current (top) study rather than any prior study, we need to call
+    # `warp` on the current study last.
+    data = self._trials_to_data(self._trials)
+    return data, prior_data
+
   @profiler.record_runtime
   def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
     logging.info('Suggest called with count=%d', count)
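
The ordering constraint documented in `_generate_data` can be illustrated with a toy stateful warper. `ToyWarper` is a hypothetical stand-in for the remembered-support-points behavior described in the comment above, not Vizier's real output-warping API:

class ToyWarper:
  """Remembers the support points from the most recent `warp` call."""

  def __init__(self):
    self._support = None

  def warp(self, labels):
    self._support = list(labels)      # the last `warp` call wins
    return [x * 2.0 for x in labels]  # stand-in transform

  def unwarp(self, labels):
    assert self._support is not None  # relative to the last warp's support
    return [x / 2.0 for x in labels]

warper = ToyWarper()
warper.warp([1.0, 2.0])  # prior study: its support points get overwritten
warper.warp([3.0, 4.0])  # current (top) study is warped last...
warper.unwarp([6.0])     # ...so unwarp is relative to the current study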
@@ -393,8 +510,8 @@ def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
 
     suggest_start_time = datetime.datetime.now()
     logging.info('Updating the designer state based on trials...')
-    data = self._trials_to_data(self._trials)
-    gp = self._update_gp(data)
+    data, prior_data = self._generate_data()
+    gp = self._update_gp(data, prior_data=prior_data)
 
     # Define acquisition function.
     scoring_fn = self._scoring_function_factory(

@@ -437,7 +554,8 @@ def sample(
     if not trials:
       return np.zeros((num_samples, 0))
 
-    gp = self._update_gp(self._trials_to_data(self._trials))
+    data, prior_data = self._generate_data()
+    gp = self._update_gp(data, prior_data=prior_data)
     xs = self._converter.to_features(trials)
     xs = types.ModelInput(
         continuous=xs.continuous.replace_fill_value(0.0),
