@@ -129,6 +129,17 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
 
   _last_computed_gp: gp_models.GPState = attr.field(init=False)
 
+  # The studies used in transfer learning. Ordered in training order, i.e.
+  # a GP is trained on `_prior_studies[0]` first, then one is trained on the
+  # residuals of `_prior_studies[1]` from the GP trained on `_prior_studies[0]`,
+  # and so on.
+  _prior_studies: list[vza.CompletedTrials] = attr.field(
+      factory=list, init=False
+  )
+  _incorporated_prior_study_count: int = attr.field(
+      default=0, kw_only=True, init=False
+  )
+
   default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
       strategy_factory=es.VectorizedEagleStrategyFactory()
   )
@@ -205,6 +216,32 @@ def update(
     del all_active
     self._trials.extend(copy.deepcopy(completed.trials))
 
+  def update_priors(self, prior_studies: Sequence[vza.CompletedTrials]) -> None:
+    """Updates the list of prior studies for transfer learning.
+
+    Each element is treated as a new prior study and is stacked in the order
+    received, i.e. the first entry trains the first GP, the second entry trains
+    a GP on the residuals of the first GP, etc.
+
+    See section 3.3 of https://dl.acm.org/doi/10.1145/3097983.3098043 for more
+    information, or see `gp/gp_models.py` and `gp/transfer_learning.py`.
+
+    Transfer learning is resilient to bad priors.
+
+    Multiple calls are permitted. It is up to the caller to ensure that each
+    entry in `prior_studies` has a matching `ProblemStatement`; otherwise
+    behavior is undefined.
+
+    TODO: Decide on whether this method should become part of an
+    interface.
+
+    Args:
+      prior_studies: A list of completed-trial lists, one per prior study. The
+        designer trains a prior GP for each `CompletedTrials` entry, in the
+        order received.
+    """
+    self._prior_studies.extend(copy.deepcopy(prior_studies))
+
   @property
   def _metric_info(self) -> vz.MetricInformation:
     return self._problem.metric_information.item()
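
As a hedged usage sketch (not part of this commit): the names `problem`, `prior_trials_a`, `prior_trials_b`, and `current_trials` below are illustrative placeholders, and the exact `CompletedTrials`/`ActiveTrials` constructors are assumed from the existing designer interface rather than shown in this diff.

# Hypothetical usage of the new `update_priors` entry point.
designer = VizierGPBandit(problem)

# Priors stack in the order received: a base GP is trained on
# `prior_trials_a`, then a GP on `prior_trials_b`'s residuals from it.
designer.update_priors([
    vza.CompletedTrials(prior_trials_a),
    vza.CompletedTrials(prior_trials_b),
])

# The current study flows through the usual `update` path.
designer.update(
    completed=vza.CompletedTrials(current_trials),
    all_active=vza.ActiveTrials([]),
)
suggestions = designer.suggest(count=1)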
@@ -286,46 +323,103 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
     return types.ModelData(model_data.features, labels)
 
   @_experimental_override_allowed
-  def _train_gp(
+  def _create_gp_spec(
       self, data: types.ModelData, ard_rng: jax.random.KeyArray
-  ) -> gp_models.GPState:
-    """Overrideable training of a pre-computed ensemble GP."""
-    trained_gp = gp_models.train_gp(
-        spec=gp_models.GPTrainingSpec(
-            ard_optimizer=self._ard_optimizer,
-            ard_rng=ard_rng,
-            coroutine=gp_models.get_vizier_gp_coroutine(
-                features=data.features, linear_coef=self._linear_coef
-            ),
-            ensemble_size=self._ensemble_size,
-            ard_random_restarts=self._ard_random_restarts,
+  ) -> gp_models.GPTrainingSpec:
+    """Overrideable creation of a training spec for a GP model."""
+    return gp_models.GPTrainingSpec(
+        ard_optimizer=self._ard_optimizer,
+        ard_rng=ard_rng,
+        coroutine=gp_models.get_vizier_gp_coroutine(
+            features=data.features, linear_coef=self._linear_coef
         ),
-        data=data,
+        ensemble_size=self._ensemble_size,
+        ard_random_restarts=self._ard_random_restarts,
     )
-    return trained_gp
+
+  @_experimental_override_allowed
+  def _train_gp_with_priors(
+      self,
+      data: types.ModelData,
+      ard_rng: jax.random.KeyArray,
+      priors: Sequence[types.ModelData],
+  ) -> gp_models.GPState:
+    """Trains a transfer-learning-enabled GP with prior studies.
+
+    Args:
+      data: Top-level data on which to train a GP.
+      ard_rng: RNG for the ARD optimization of GP parameters.
+      priors: Data for each sequential prior to train for transfer learning.
+        Assumed to be in training order, i.e. `priors[0]` trains the first
+        GP, `priors[1]` trains a GP on the residuals of the GP trained on
+        `priors[0]`, and so on.
+
+    Returns:
+      A trained pre-computed ensemble GP.
+    """
+    ard_rngs = jax.random.split(ard_rng, len(priors) + 1)
+
+    # Order `specs` in training order, i.e. `specs[0]` is trained first.
+    specs = [
+        self._create_gp_spec(prior_data, ard_rngs[i])
+        for i, prior_data in enumerate(priors)
+    ]
+
+    # Use the last RNG for the top-level spec.
+    specs.append(self._create_gp_spec(data, ard_rngs[-1]))
+
+    # Order `training_data` in training order, i.e. `training_data[0]` is
+    # trained first.
+    training_data = list(priors)
+    training_data.append(data)
+
+    # `train_gp` expects `specs` and `data` in training order, which is how
+    # they were prepared above.
+    return gp_models.train_gp(spec=specs, data=training_data)
 
   @profiler.record_runtime
-  def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
+  def _update_gp(
+      self,
+      data: types.ModelData,
+      *,
+      prior_data: Optional[Sequence[types.ModelData]] = None,
+  ) -> gp_models.GPState:
     """Computes the designer's GP and caches the result. No-op without new data.
 
     Args:
       data: Data to go into the GP.
+      prior_data: Data to train priors on, in training order.
 
     Returns:
-      GPBanditState object containing the designer's state.
+      `GPState` object containing the trained GP.
 
     1. Converts trials to features and labels.
     2. Trains a pre-computed ensemble GP.
 
     If no new trials were added since the last call, no update will occur.
     """
-    if len(self._trials) == self._incorporated_trials_count:
-      # If there's no change in the number of completed trials, don't update
-      # state. The assumption is that trials can't be removed.
+    if (
+        len(self._trials) == self._incorporated_trials_count
+        and len(self._prior_studies) == self._incorporated_prior_study_count
+    ):
+      # If there's no change in the number of completed trials or the number
+      # of priors, don't update state. The assumption is that trials can't be
+      # removed.
       return self._last_computed_gp
     self._incorporated_trials_count = len(self._trials)
+    self._incorporated_prior_study_count = len(self._prior_studies)
     self._rng, ard_rng = jax.random.split(self._rng, 2)
-    self._last_computed_gp = self._train_gp(data=data, ard_rng=ard_rng)
+
+    if not prior_data:
+      self._last_computed_gp = gp_models.train_gp(
+          spec=self._create_gp_spec(data, ard_rng),
+          data=data,
+      )
+    else:
+      self._last_computed_gp = self._train_gp_with_priors(
+          data=data, ard_rng=ard_rng, priors=prior_data
+      )
+
     return self._last_computed_gp
 
   @_experimental_override_allowed
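
The sequential residual scheme that `_train_gp_with_priors` hands to `train_gp` (section 3.3 of the paper linked in `update_priors`) reduces to a simple stacking loop. Below is a toy sketch of that idea only; `fit` and `stack` are illustrative stand-ins, not Vizier APIs, with `fit` playing the role of GP training.

import numpy as np


# Toy sketch of sequential residual stacking. `fit` is a stand-in for GP
# training; a real GP regressor would take its place.
def fit(x, y):
  # Placeholder "model": predicts the training mean everywhere.
  mean = y.mean()
  return lambda x_new: np.full(len(x_new), mean)


def stack(datasets):
  """Trains one model per dataset, each on the running stack's residuals."""
  models = []

  def predict(x):
    return sum(m(x) for m in models)

  # Training order matters: priors first, the current study last.
  for x, y in datasets:
    models.append(fit(x, y - predict(x)))
  return predict


x = np.linspace(0.0, 1.0, 8).reshape(-1, 1)
predict = stack([(x, np.sin(x[:, 0])), (x, np.cos(x[:, 0]))])
print(predict(x))  # Combined prediction: prior model plus residual model.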
@@ -379,6 +473,29 @@ def _optimize_acquisition(
         best_candidates, self._converter
     )  # [N, D]
 
+  @profiler.record_runtime
+  @_experimental_override_allowed
+  def _generate_data(
+      self,
+  ) -> tuple[types.ModelData, Optional[list[types.ModelData]]]:
+    """Converts trials to top-level and prior training data."""
+    prior_data: Optional[list[types.ModelData]] = None
+    if self._prior_studies:
+      prior_data = [
+          self._trials_to_data(prior_study.trials)
+          for prior_study in self._prior_studies
+      ]
+
+    # The top-level data must be converted last because `_output_warper`
+    # `unwarp`s labels using the support points supplied to its most recent
+    # `warp` call. It stores these support points each time `warp` is called,
+    # so the last `warp` call dictates the support points used in `unwarp`.
+    # Since we want to `unwarp` predictions based on the current (top) study
+    # rather than any prior study, we must call `warp` on the current study
+    # last.
+    data = self._trials_to_data(self._trials)
+    return data, prior_data
+
   @profiler.record_runtime
   def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
     logging.info('Suggest called with count=%d', count)
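
The warp-order constraint documented in `_generate_data` is easy to reproduce with a toy stateful warper. `MeanShiftWarper` below is hypothetical and far simpler than the real output warper; it only demonstrates why the last `warp` call wins.

# Hypothetical stateful warper illustrating the comment above: `unwarp`
# always uses the state captured by the most recent `warp` call.
class MeanShiftWarper:

  def warp(self, labels):
    # Overwrites state on every call, like the support points above.
    self._shift = sum(labels) / len(labels)
    return [l - self._shift for l in labels]

  def unwarp(self, labels):
    return [l + self._shift for l in labels]


w = MeanShiftWarper()
w.warp([10.0, 12.0])    # Prior study: shift becomes 11.0.
w.warp([0.0, 2.0])      # Current study last: shift becomes 1.0.
print(w.unwarp([0.5]))  # [1.5] -- relative to the current study, as desired.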
@@ -393,8 +510,8 @@ def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
 
     suggest_start_time = datetime.datetime.now()
     logging.info('Updating the designer state based on trials...')
-    data = self._trials_to_data(self._trials)
-    gp = self._update_gp(data)
+    data, prior_data = self._generate_data()
+    gp = self._update_gp(data, prior_data=prior_data)
 
     # Define acquisition function.
     scoring_fn = self._scoring_function_factory(
@@ -437,7 +554,8 @@ def sample(
     if not trials:
       return np.zeros((num_samples, 0))
 
-    gp = self._update_gp(self._trials_to_data(self._trials))
+    data, prior_data = self._generate_data()
+    gp = self._update_gp(data, prior_data=prior_data)
     xs = self._converter.to_features(trials)
     xs = types.ModelInput(
         continuous=xs.continuous.replace_fill_value(0.0),
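
With this change, `sample` obtains priors through the same `_generate_data`/`_update_gp` path as `suggest`, so both reuse one cached GP. A hedged call sketch, reusing the hypothetical `designer` from the earlier sketch and assuming only the `trials`-plus-`num_samples` portion of the `sample` signature visible in this hunk:

# `pending_trials` is an illustrative placeholder.
samples = designer.sample(pending_trials, num_samples=100)
print(samples.shape)  # (num_samples, len(pending_trials)), per the zeros fallback above.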