diff --git a/src/poli_baselines/solvers/__init__.py b/src/poli_baselines/solvers/__init__.py index 590746c..69a024c 100644 --- a/src/poli_baselines/solvers/__init__.py +++ b/src/poli_baselines/solvers/__init__.py @@ -1,17 +1,17 @@ -# from .simple.random_mutation import RandomMutation -# from .simple.continuous_random_mutation import ContinuousRandomMutation -# from .simple.genetic_algorithm import FixedLengthGeneticAlgorithm - -# from .bayesian_optimization.vanilla_bayesian_optimization import ( -# VanillaBayesianOptimization, -# ) -# from .bayesian_optimization.line_bayesian_optimization import LineBO -# from .bayesian_optimization.saas_bayesian_optimization import SAASBO -# from .bayesian_optimization.baxus import BAxUS - -# from .bayesian_optimization.latent_space_bayesian_optimization import ( -# LatentSpaceBayesianOptimization, -# ) - -# from .evolutionary_strategies.cma_es import CMA_ES -# from .multi_objective.nsga_ii import DiscreteNSGAII +# from .simple.random_mutation import RandomMutation +# from .simple.continuous_random_mutation import ContinuousRandomMutation +# from .simple.genetic_algorithm import FixedLengthGeneticAlgorithm + +# from .bayesian_optimization.vanilla_bayesian_optimization import ( +# VanillaBayesianOptimization, +# ) +# from .bayesian_optimization.line_bayesian_optimization import LineBO +# from .bayesian_optimization.saas_bayesian_optimization import SAASBO +# from .bayesian_optimization.baxus import BAxUS + +# from .bayesian_optimization.latent_space_bayesian_optimization import ( +# LatentSpaceBayesianOptimization, +# ) + +# from .evolutionary_strategies.cma_es import CMA_ES +# from .multi_objective.nsga_ii import DiscreteNSGAII diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/__init__.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py new file mode 100644 index 0000000..985d852 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py @@ -0,0 +1,44 @@ +import numpy as np +from poli.core.abstract_black_box import AbstractBlackBox + +from poli_baselines.core.step_by_step_solver import StepByStepSolver +from poli_baselines.solvers.bayesian_optimization.amortized.data import ( + samples_from_arrays, + Population, +) +from poli_baselines.solvers.bayesian_optimization.amortized.deep_evolution_solver import ( + MutationPredictorSolver, +) +from poli_baselines.solvers.bayesian_optimization.amortized.domains import ( + FixedLengthDiscreteDomain, + Vocabulary, +) + + +class AmortizedBOWrapper(StepByStepSolver): + def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray): + super().__init__(black_box, x0, y0) + self.problem_info = black_box.get_black_box_info() + alphabet = self.problem_info.get_alphabet() + if not self.problem_info.sequences_are_aligned(): + alphabet = alphabet + [self.problem_info.get_padding_token()] + self.domain = FixedLengthDiscreteDomain( + vocab=Vocabulary(alphabet), + length=x0.shape[1], + ) + self.solver = MutationPredictorSolver( + domain=self.domain, + initialize_dataset_fn=lambda *args, **kwargs: self.domain.encode(x0), + ) + + def next_candidate(self) -> np.ndarray: + samples = samples_from_arrays( + structures=self.domain.encode(self.x0.tolist()), rewards=self.y0.tolist() + ) + x = 
self.solver.propose(num_samples=1, population=Population(samples)) + s = list(self.domain.decode(x)[0]) + if not self.problem_info.sequences_are_aligned(): + s = s + [self.problem_info.get_padding_token()] * ( + self.problem_info.get_max_sequence_length() - len(s) + ) + return np.array([s]) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py new file mode 100644 index 0000000..8138056 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solver base class.""" + +import abc + +from absl import logging + +from poli_baselines.solvers.bayesian_optimization.amortized import utils + + +class BaseSolver(abc.ABC): + """Solver base class.""" + + def __init__( + self, domain, random_state=None, name=None, log_level=logging.INFO, **kwargs + ): + """Creates an instance of this class. + + Args: + domain: An instance of a `Domain`. + random_state: An instance of or integer seed to build a + `np.random.RandomState`. + name: The name of the solver. If `None`, will use the class name. + log_level: The logging level of the solver-specific logger. + -2=ERROR, -1=WARN, 0=INFO, 1=DEBUG. + **kwargs: Named arguments stored in `self.cfg`. + """ + self._domain = domain + self._name = name or self.__class__.__name__ + self._random_state = utils.get_random_state(random_state) + self._log = utils.get_logger(self._name, level=log_level) + + cfg = utils.Config(self._config()) + cfg.update(kwargs) + self.cfg = cfg + + def _config(self): + return {} + + @property + def domain(self): + """Return the optimization domain.""" + return self._domain + + @property + def name(self): + """Returns the solver name.""" + return self._name + + def __str__(self): + return self._name + + @abc.abstractmethod + def propose(self, num_samples, population=None, pending_samples=None, counter=0): + """Proposes num_samples from `self.domain`. + + Args: + num_samples: The number of samples to return. + population: A `Population` of samples or None if the population is empty. + pending_samples: A list of structures without reward that were already + proposed. + counter: The number of times `propose` has been called with the same + `population`. Can be used by solver to avoid repeated computations on + the same `population`, e.g. updating a model. + + Returns: + `num_samples` structures from the domain. + """ diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py new file mode 100644 index 0000000..2b18658 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py @@ -0,0 +1,706 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. 
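Review note: for orientation, here is a minimal sketch of how the `AmortizedBOWrapper` introduced above is meant to be driven. The construction of the poli problem is left abstract (an assumption, not part of this patch); only the constructor and `next_candidate()` come from the wrapper itself, and bookkeeping of new observations is normally handled by the inherited `StepByStepSolver` machinery.

```python
import numpy as np

from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import (
    AmortizedBOWrapper,
)

# Assumed setup: any poli problem over discrete sequences provides a black box
# plus initial designs; how `problem` is created is outside the scope of this patch.
black_box, x0 = problem.black_box, problem.x0  # hypothetical `problem` object
y0 = black_box(x0)

solver = AmortizedBOWrapper(black_box, x0, y0)
x_next = solver.next_candidate()  # shape [1, max_sequence_length], tokens from the alphabet
y_next = black_box(x_next)
```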
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Population class for keeping track of structures and rewards.""" + +import collections +import itertools +import typing + +import attr +import numpy as np +import pandas as pd +import tensorflow.compat.v1 as tf + +from poli_baselines.solvers.bayesian_optimization.amortized import domains +from poli_baselines.solvers.bayesian_optimization.amortized import utils + +# DatasetSample defines the structure of samples in tf.data.Datasets for +# pre-training solvers, whereas Samples (see below) defines the structure of +# samples in a Population object. +DatasetSample = collections.namedtuple("DatasetSample", ["structure", "reward"]) + + +def dataset_to_population(population_or_tf_dataset): + """Converts a TF dataset to a Population if it is not already a Population.""" + if isinstance(population_or_tf_dataset, Population): + return population_or_tf_dataset + else: + return Population.from_dataset(population_or_tf_dataset) + + +def serialize_structure(structure): + """Converts a structure to a string.""" + structure = np.asarray(structure) + dim = len(structure.shape) + if dim != 1: + raise NotImplementedError(f"`structure` must be 1d but is {dim}d!") + return domains.SEP_TOKEN.join(str(token) for token in structure) + + +def serialize_structures(structures, **kwargs): + """Converts a list of structures to a list of strings.""" + return [serialize_structure(structure, **kwargs) for structure in structures] + + +def deserialize_structure(serialized_structure, dtype=np.int32): + """Converts a string to a structure. + + Args: + serialized_structure: A structure produced by `serialize_structure`. + dtype: The data type of the output numpy array. + + Returns: + A numpy array with `dtype`. + """ + return np.asarray( + [token for token in serialized_structure.split(domains.SEP_TOKEN)], dtype=dtype + ) + + +def deserialize_structures(structures, **kwargs): + """Converts a list of strings to a list of structures. + + Args: + structures: A list of strings produced by `serialize_structures`. + **kwargs: Named arguments passed to `deserialize_structure`. + + Returns: + A list of numpy array. + """ + return [deserialize_structure(structure, **kwargs) for structure in structures] + + +def serialize_population_frame(frame, inplace=False, domain=None): + """Serializes a population `pd.DataFrame` for representing it as plain text. + + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + inplace: Whether to serialize `frame` inplace instead of creating a copy. + domain: An optional domain for decoding structures. If provided, will + add a column `decoded_structure` with the serialized decoded structures. + + Returns: + A `pd.DataFrame` with serialized structures. 
+ """ + if not inplace: + frame = frame.copy() + if domain: + frame["decoded_structure"] = serialize_structures( + domain.decode(frame["structure"], as_str=False) + ) + frame["structure"] = serialize_structures(frame["structure"]) + return frame + + +def deserialize_population_frame(frame, inplace=False): + """Deserializes a population `pd.DataFrame` from plain text. + + Args: + frame: A `pd.DataFrame` produced by `serialize_population_frame`. + inplace: Whether to deserialize `frame` inplace instead of creating a copy. + + Returns: + A `pd.DataFrame` with deserialized structures. + """ + if not inplace: + frame = frame.copy() + frame["structure"] = deserialize_structures(frame["structure"]) + if "decoded_structure" in frame.columns: + frame["decoded_structure"] = deserialize_structures( + frame["decoded_structure"], dtype=str + ) + return frame + + +def population_frame_to_csv( + frame, path_or_buf=None, domain=None, index=False, **kwargs +): + """Converts a population `pd.DataFrame` to a csv table. + + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + path_or_buf: File path or object. If `None`, the result is returned as a + string. Otherwise write the csv table to that file. + domain: A optional domain for decoding structures. + index: Whether to store the index of `frame`. + **kwargs: Named arguments passed to `frame.to_csv`. + + Returns: + If `path_or_buf` is `None`, returns the resulting csv format as a + string. Otherwise returns `None`. + """ + if frame.empty: + raise ValueError("Cannot write empty population frame to CSV file!") + frame = serialize_population_frame(frame, domain=domain) + return frame.to_csv(path_or_buf, index=index, **kwargs) + + +def population_frame_from_csv(path_or_buf, **kwargs): + """Reads a population `pd.DataFrame` from a file. + + Args: + path_or_buf: A string path of file buffer. + **kwargs: Named arguments passed to `pd.read_csv`. + + Returns: + A `pd.DataFrame`. + """ + frame = pd.read_csv(path_or_buf, dtype={"metadata": object}, **kwargs) + frame = deserialize_population_frame(frame) + return frame + + +def subtract_mean_batch_reward(population): + """Returns new Population where each batch has mean-zero rewards.""" + df = population.to_frame() + mean_dict = df.groupby("batch_index").reward.mean().to_dict() + + def reward_for_sample(sample): + return sample.reward - mean_dict[sample.batch_index] + + shifted_samples = [ + sample.copy(reward=reward_for_sample(sample)) for sample in population + ] + return Population(shifted_samples) + + +def _to_immutable_array(array): + to_return = np.array(array) + to_return.setflags(write=False) + return to_return + + +class _NumericConverter(object): + """Helper class for converting values to a numeric data type.""" + + def __init__(self, dtype, min_value=None, max_value=None): + self._dtype = dtype + self._min_value = min_value + self._max_value = max_value + + def __call__(self, value): + """Validates and converts `value` to `self._dtype`.""" + if value is None: + return value + if not np.isreal(value): + raise TypeError("%s is not numeric!" % value) + value = self._dtype(value) + if self._min_value is not None and value < self._min_value: + raise TypeError("%f < %f" % (value, self._min_value)) + if self._max_value is not None and value > self._max_value: + raise TypeError("%f > %f" % (value, self._max_value)) + return value + + +@attr.s( + frozen=True, # Make it immutable. + slots=True, # Improve memory overhead. + eq=False, # Because we override __eq__. 
+) +class Sample(object): + """Immutable container for a structure, reward, and additional data. + + Attributes: + key: (str) A unique identifier of the sample. If not provided, will create + a unique identifier using `utils.create_unique_id`. + structure: (np.ndarray) The structure. + reward: (float, optional) The reward. + batch_index: (int, optional) The batch index within the population. + infeasible: Whether the sample was marked as infeasible by the evaluator. + metadata: (any type, optional) Additional meta-data. + """ + + structure: np.ndarray = attr.ib( + factory=np.array, converter=_to_immutable_array + ) # pytype: disable=wrong-arg-types # attr-stubs + reward: float = attr.ib(converter=_NumericConverter(float), default=None) + batch_index: int = attr.ib( + converter=_NumericConverter(int, min_value=0), default=None + ) + infeasible: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) + key: str = attr.ib( + factory=utils.create_unique_id, + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + metadata: typing.Dict[str, typing.Any] = attr.ib(default=None) + + def __eq__(self, other): + """Compares samples irrespective of their key.""" + return ( + self.reward == other.reward + and self.batch_index == other.batch_index + and self.infeasible == other.infeasible + and self.metadata == other.metadata + and np.array_equal(self.structure, other.structure) + ) + + def equal(self, other): + """Compares samples including their key.""" + return self == other and self.key == other.key + + def to_tfexample(self): + """Converts a Sample to a tf.Example.""" + features = dict( + structure=tf.train.Feature( + int64_list=tf.train.Int64List(value=self.structure) + ), + reward=tf.train.Feature(float_list=tf.train.FloatList(value=[self.reward])), + batch_index=tf.train.Feature( + int64_list=tf.train.Int64List(value=[self.batch_index]) + ), + ) + return tf.train.Example(features=tf.train.Features(feature=features)) + + def to_dict(self, dict_factory=collections.OrderedDict): + """Returns a dictionary from field names to values. + + Args: + dict_factory: A class that implements a dict factory method. + + Returns: + A dict of type `dict_factory` + """ + return attr.asdict(self, dict_factory=dict_factory) + + def copy(self, new_key=True, **kwargs): + """Returns a copy of this Sample with values overridden by **kwargs.""" + if new_key: + kwargs["key"] = utils.create_unique_id() + return attr.evolve(self, **kwargs) + + +def samples_from_arrays(structures, rewards=None, batch_index=None, metadata=None): + """Makes a generator of Samples from fields. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of float rewards. If None, the corresponding Samples are + given each given a reward of None. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. + metadata: Metadata to store in the Sample. + + Yields: + A generator of Samples + """ + structures = utils.to_array(structures) + + if metadata is None: + metadata = [None] * len(structures) + + if rewards is None: + rewards = [None] * len(structures) + else: + rewards = utils.to_array(rewards) + + if len(structures) != len(rewards): + raise ValueError( + "Structures and rewards must be same length. Are %s and %s" + % (len(structures), len(rewards)) + ) + if len(metadata) != len(rewards): + raise ValueError( + "Metadata and rewards must be same length. 
Are %s and %s" + % (len(metadata), len(rewards)) + ) + + if batch_index is None: + batch_index = 0 + if isinstance(batch_index, int): + batch_index = [batch_index] * len(structures) + + for structure, reward, batch_index, meta in zip( + structures, rewards, batch_index, metadata + ): + yield Sample( + structure=structure, reward=reward, batch_index=batch_index, metadata=meta + ) + + +def parse_tf_example(example_proto): + """Converts tf.Example proto to dict of Tensors. + + Args: + example_proto: A raw tf.Example proto. + Returns: + A dict of Tensors with fields structure, reward, and batch_index. + """ + + feature_description = dict( + structure=tf.FixedLenSequenceFeature((), tf.int64, allow_missing=True), + reward=tf.FixedLenFeature([1], tf.float32), + batch_index=tf.FixedLenFeature([1], tf.int64), + ) + + return tf.io.parse_single_example( + serialized=example_proto, features=feature_description + ) + + +class Population(object): + """Data structure for storing Samples.""" + + def __init__(self, samples=None): + """Construct a Population. + + Args: + samples: An iterable of Samples + """ + self._samples = collections.OrderedDict() + self._batch_to_sample_keys = collections.defaultdict(list) + + if samples is not None: + self.add_samples(samples) + + def __str__(self): + if self.empty: + return "" + return "" % ( + len(self), + dict(self.best().to_dict()), + ) + + def __len__(self): + return len(self._samples) + + def __iter__(self): + return self._samples.values().__iter__() + + def __eq__(self, other): + if not isinstance(other, Population): + raise ValueError( + "Cannot compare equality with an object of " + "type %s" % (str(type(other))) + ) + + return len(self) == len(other) and all(s1 == s2 for s1, s2 in zip(self, other)) + + def __add__(self, other): + """Adds samples to this population and returns a new Population.""" + return Population(itertools.chain(self, other)) + + def __getitem__(self, key_or_index): + if isinstance(key_or_index, str): + return self._samples[key_or_index] + else: + return list(self._samples.values())[key_or_index] + + def __contains__(self, key): + if isinstance(key, Sample): + key = key.key + return key in self._samples + + def copy(self): + """Copies the population.""" + return Population(self.samples) + + @property + def samples(self): + """Returns the population Samples as a list.""" + return list(self._samples.values()) + + def add_sample(self, sample): + """Add copy of sample to population.""" + if sample.key in self._samples: + raise ValueError("Sample with key %s already exists in the population!") + self._samples[sample.key] = sample + batch_idx = sample.batch_index + self._batch_to_sample_keys[batch_idx].append(sample.key) + + def add_samples(self, samples): + """Convenience method for adding multiple samples.""" + for sample in samples: + self.add_sample(sample) + + @property + def empty(self): + return not self + + @property + def batch_indices(self): + """Returns a sorted list of unique batch indices.""" + return sorted(self._batch_to_sample_keys) + + @property + def max_batch_index(self): + """Returns the maximum batch index.""" + if self.empty: + raise ValueError("Population empty!") + return self.batch_indices[-1] + + @property + def current_batch_index(self): + """Return the maximum batch index or -1 if the population is empty.""" + return -1 if self.empty else self.max_batch_index + + def get_batches(self, batch_indices, exclude=False, validate=True): + """ "Extracts certain batches from the population. 
+ + Ignores batches that do not exist in the population. To validate if a + certain batch exists use `batch_index in population.batch_indices`. + + Args: + batch_indices: An integer, iterable of integers, or `slice` object + for selecting batches. + exclude: If true, will return all batches but the ones that are selected. + validate: Whether to raise an exception if a batch index is invalid + instead of ignoring. + + Returns: + A `Population` with the selected batches. + """ + batch_indices = utils.get_indices( + self.batch_indices, batch_indices, exclude=exclude, validate=validate + ) + sample_idxs = [] + for batch_index in batch_indices: + sample_idxs.extend(self._batch_to_sample_keys.get(batch_index, [])) + samples = [self[idx] for idx in sample_idxs] + return Population(samples) + + def get_batch(self, *args, **kwargs): + return self.get_batches(*args, **kwargs) + + def get_last_batch(self, **kwargs): + """Returns the last batch from the population.""" + return self.get_batch(-1, **kwargs) + + def get_last_batches(self, n=1, **kwargs): + """Selects the last n batches.""" + return self.get_batches(self.batch_indices[-n:], **kwargs) + + def to_structures_and_rewards(self): + """Return (list of structures, list of rewards) in the Population.""" + structures = [] + rewards = [] + for sample in self: + structures.append(sample.structure) + rewards.append(sample.reward) + return structures, rewards + + @property + def structures(self): + """Returns the structure of all samples in the population.""" + return [sample.structure for sample in self] + + @property + def rewards(self): + """Returns the reward of all samples in the population.""" + return [sample.reward for sample in self] + + @staticmethod + def from_arrays(structures, rewards=None, batch_index=0, metadata=None): + """Creates Population from the specified fields. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of float rewards. If None, the corresponding Samples are + given each given a reward of None. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. + metadata: Metadata to store in the Samples. + + Returns: + A Population. + """ + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + return Population(samples) + + def add_batch(self, structures, rewards=None, batch_index=None, metadata=None): + """Adds a batch of samples to the Population. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of rewards. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. If `None`, uses `self.current_batch + 1`. + metadata: Metadata to store in the Samples. 
+ """ + if batch_index is None: + batch_index = self.current_batch_index + 1 + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + self.add_samples(samples) + + def head(self, n): + """Returns new Population containing first n samples.""" + + return Population(self.samples[:n]) + + @staticmethod + def from_dataset(dataset): + """Converts dataset of DatasetSample to Population.""" + samples = utils.arrays_from_dataset(dataset) + return Population.from_arrays(samples.structure, samples.reward) + + def to_dataset(self): + """Converts the population to a `tf.data.Dataset` with `DatasetSample`s.""" + structures, rewards = self.to_structures_and_rewards() + return utils.dataset_from_tensors( + DatasetSample(structure=structures, reward=rewards) + ) + + @staticmethod + def from_tfrecord(filename): + """Reads Population from tfrecord file.""" + raw_dataset = tf.data.TFRecordDataset([filename]) + parsed_dataset = raw_dataset.map(parse_tf_example) + + def _record_to_dict(record): + mapping = {key: utils.to_array(value) for key, value in record.items()} + if "batch_index" in mapping: + mapping["batch_index"] = int(mapping["batch_index"]) + return mapping + + return Population(Sample(**_record_to_dict(r)) for r in parsed_dataset) + + def to_tfrecord(self, filename): + """Writes Population to tfrecord file.""" + + with tf.python_io.TFRecordWriter(filename) as writer: + for sample in self.samples: + writer.write(sample.to_tfexample().SerializeToString()) + + def to_frame(self): + """Converts a `Population` to a `pd.DataFrame`.""" + records = [sample.to_dict() for sample in self.samples] + return pd.DataFrame.from_records(records) + + @classmethod + def from_frame(cls, frame): + """Converts a `pd.DataFrame` to a `Population`.""" + if frame.empty: + return cls() + population = cls() + for _, row in frame.iterrows(): + sample = Sample(**row.to_dict()) + population.add_sample(sample) + return population + + def to_csv(self, path, domain=None): + """Stores a population to a CSV file. + + Args: + path: The output file path. + domain: An optional `domains.Domain`. If provided, will also store + decoded structures in the CSV file. + """ + population_frame_to_csv(self.to_frame(), path, domain=domain) + + @classmethod + def from_csv(cls, path): + """Restores a population from a CSV file. + + Args: + path: The CSV file path. + + Returns: + An instance of a `Population`. + """ + return cls.from_frame( + population_frame_from_csv(path).drop( + columns=["decoded_structure"], errors="ignore" + ) + ) + + def best_n(self, n=1, q=None, discard_duplicates=False, blacklist=None): + """Returns the best n samples. + + Note that ties are broken deterministically. + + Args: + n: Max number to return + q: A float in (0, 1) corresponding to the minimum quantile for selecting + samples. If provided, `n` is ignored and samples with a reward >= + this quantile are selected. + discard_duplicates: If True, when several samples have the same structure, + return only one of them (the selected one is unspecified). + blacklist: Iterable of structures that should be excluded. + + Returns: + Population containing the best n Samples, sorted in decreasing order of + reward (output[0] is the best). Returns less than n if there are fewer + than n Samples in the population. + """ + if self.empty: + raise ValueError("Population empty.") + + samples = self.samples + if blacklist: + samples = self._filter(samples, blacklist) + + # are unique. 
+ if discard_duplicates and len(samples) > 1: + samples = utils.deduplicate_samples(samples) + + samples = sorted(samples, key=lambda sample: sample.reward, reverse=True) + if q is not None: + q_value = np.quantile([sample.reward for sample in samples], q) + return Population( + [sample for sample in samples if sample.reward >= q_value] + ) + else: + return Population(samples[:n]) + + def best(self, blacklist=None): + return self.best_n(1, blacklist=blacklist).samples[0] + + def _filter(self, samples, blacklist): + blacklist = set(utils.hash_structure(structure) for structure in blacklist) + return Population( + sample + for sample in samples + if utils.hash_structure(sample.structure) not in blacklist + ) + + def contains_structure(self, structure): + return self.contains_structures([structure])[0] + + # TODO(dbelanger): this has computational overhead because it hashes the + # entire population. If this function is being called often, the set of hashed + # structures should be built up incrementally as elements are added to the + # population. Right now, the function is rarely called. + # TODO(ddohan): Build this just in time: track what hasn't been hashed + # since last call, and add any new structures before checking contain. + def contains_structures(self, structures): + structures_in_population = set( + utils.hash_structure(sample.structure) for sample in self.samples + ) + return [ + utils.hash_structure(structure) in structures_in_population + for structure in structures + ] + + def deduplicate(self, select_best=False): + """De-duplicates Samples with identical structures. + + Args: + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. + + Returns: + A Population with de-duplicated samples. + """ + return Population(utils.deduplicate_samples(self, select_best=select_best)) + + def discard_infeasible(self): + """Returns a new `Population` with all infeasible samples removed.""" + return Population(sample for sample in self if not sample.infeasible) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py new file mode 100644 index 0000000..54d7ae2 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py @@ -0,0 +1,1013 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Deep Evolution solver. + +Uses a neural net to predict the position and +mutation function to apply to a string. Neural net takes the string and predicts +1) [Batch x length] logits over positions in the string +2) [Batch x length x n_mutations] logits over mutation function for every +position in the string. 
+First, we sample the position from the position logits, take the logits +corresponding to the chosen position and sample the index of the mutation +function to apply to this position in the string. Currently, we apply +one mutation at a time. Finally, update the network parameters using REINFORCE +gradient estimator, where the advantage is the difference between parent and +child rewards. The log-likelihood is the sum of position and mutation +log-likelihoods. +By default, no selection is performed (we continue mutating the same batch, +use_selection_of_best = False). If use_selection_of_best=True, we choose best +samples from the previous batch and sample them with replacement to create +a new batch. +""" +import functools + +# from absl import logging +# import gin +import jax +from jax.example_libraries import stax +from jax.example_libraries.optimizers import adam +import jax.numpy as jnp +import jax.random as jrand +from jax.scipy.special import logsumexp +import numpy as np + +from poli_baselines.solvers.bayesian_optimization.amortized import base_solver +from poli_baselines.solvers.bayesian_optimization.amortized import data +from poli_baselines.solvers.bayesian_optimization.amortized import utils + + +def logsoftmax(x, axis=-1): + """Apply log softmax to an array of logits, log-normalizing along an axis.""" + return x - logsumexp(x, axis, keepdims=True) + + +def softmax(x, axis=-1): + return jnp.exp(logsoftmax(x, axis)) + + +def one_hot(x, k): + """Create a one-hot encoding of x of size k.""" + return jnp.eye(k)[x] + + +def gumbel_max_sampler(logits, temperature, rng): + """Sample fom categorical distribution using Gumbel-Max trick. + + Gumbel-Max trick: + https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + https://arxiv.org/abs/1411.0030 + + Args: + logits: Unnormalized logits for categorical distribution. + [batch x n_mutations_to_sample x n_mutation_types] + temperature: temperature parameter for Gumbel-Max. The lower the + temperature, the closer the sample is to one-hot-encoding. + rng: Jax random number generator + + Returns: + class_assignments: Sampled class assignments [batch] + log_likelihoods: Log-likelihoods of the sampled mutations [batch] + """ + + # Normalize the logits + logits = logsoftmax(logits) + + gumbel_noise = jrand.gumbel(rng, logits.shape) + softmax_logits = (logits + gumbel_noise) / temperature + soft_assignments = softmax(softmax_logits, -1) + class_assignments = jnp.argmax(soft_assignments, -1) + assert len(class_assignments.shape) == 2 + # Output shape: [batch x num_mutations] + + return class_assignments + + +########################################## +# Mutation-related helper functions +def _mutate_position(structure, pos_mask, fn): + """Apply mutation fn to position specified by pos_mask.""" + structure = np.array(structure).copy() + pos_mask = np.array(pos_mask).astype(int) + structure[pos_mask == 1] = fn(structure[pos_mask == 1]) + return structure + + +def set_pos(x, pos_mask, val): + return _mutate_position(x, pos_mask, fn=lambda x: val) + + +def apply_mutations( + samples, mutation_types, pos_masks, mutations, use_assignment_mutations=False +): + """Apply the mutations specified by mutation types to the batch of strings. + + Args: + samples: Batch of strings [batch x str_length] + mutation_types: IDs of mutation types to be applied to each string + [Batch x num_mutations] + pos_masks: One-hot encoding [Batch x num_mutations x str_length] + of the positions to be mutate in each string. 
+ "num_mutations" positions will be mutated per string. + mutations: A list of possible mutation functions. + Functions should follow the format: fn(x, domain, pos_mask), + use_assignment_mutations: bool. Whether mutations are defined as + "Set position X to character C". If use_assignment_mutations=True, + then vectorize procedure of applying mutations to the string. + The index of mutation type should be equal to the index of the character. + Gives considerable speed-up to this function. + + Returns: + perturbed_samples: Strings perturbed according to the mutation list. + """ + batch_size = samples.shape[0] + assert len(mutation_types) == batch_size + assert len(pos_masks) == batch_size + + str_length = samples.shape[1] + assert pos_masks.shape[-1] == str_length + + # Check that number of mutations is consistent in mutation_types and positions + assert mutation_types.shape[1] == pos_masks.shape[1] + + num_mutations = mutation_types.shape[1] + + # List of batched samples with 0,1,2,... mutations + # First element of the list contains original samples + # Last element has samples with all mutations applied to the string + perturbed_samples_with_i_mutations = [samples] + for i in range(num_mutations): + + perturbed_samples = [] + samples_to_perturb = perturbed_samples_with_i_mutations[-1] + + if use_assignment_mutations: + perturbed_samples = samples_to_perturb.copy() + mask = pos_masks[:, i].astype(int) + # Assumes mutations are defined as "Set position to the character C" + perturbed_samples[np.array(mask) == 1] = mutation_types[:, i] + else: + for j in range(batch_size): + sample = samples_to_perturb[j].copy() + + pos = pos_masks[j, i] + mut_id = mutation_types[j, i] + + mutation = mutations[int(mut_id)] + perturbed_samples.append(mutation(sample, pos)) + perturbed_samples = np.stack(perturbed_samples) + + assert perturbed_samples.shape == samples.shape + perturbed_samples_with_i_mutations.append(perturbed_samples) + + states = jnp.stack(perturbed_samples_with_i_mutations, 0) + assert states.shape == (num_mutations + 1,) + samples.shape + return states + + +########################################## +# pylint: disable=invalid-name +def OneHot(depth): + """Layer for transforming inputs to one-hot encoding.""" + + def init_fun(rng, input_shape): + del rng + return input_shape + (depth,), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + # Perform one-hot encoding + return jnp.eye(depth)[inputs.astype(int)] + + return init_fun, apply_fun + + +def ExpandDims(axis=1): + """Layer for expanding dimensions.""" + + def init_fun(rng, input_shape): + del rng + input_shape = tuple(input_shape) + if axis < 0: + dims = len(input_shape) + new_axis = dims + 1 - axis + else: + new_axis = axis + return (input_shape[:new_axis] + (1,) + input_shape[new_axis:]), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return jnp.expand_dims(inputs, axis) + + return init_fun, apply_fun + + +def AssertNonZeroShape(): + """Layer for checking that no dimension has zero length.""" + + def init_fun(rng, input_shape): + del rng + return input_shape, () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + assert 0 not in inputs.shape + return inputs + + return init_fun, apply_fun + + +# pylint: enable=invalid-name + + +def squeeze_layer(axis=1): + """Layer for squeezing dimension along the axis.""" + + def init_fun(rng, input_shape): + del rng + if axis < 0: + raise ValueError("squeeze_layer: negative axis is not supported") + return (input_shape[:axis] + 
input_shape[(axis + 1) :]), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return inputs.squeeze(axis) + + return init_fun, apply_fun + + +def reduce_layer(reduce_fn=jnp.mean, axis=1): + """Apply reduction function to the array along axis.""" + + def init_fun(rng, input_shape): + del rng + assert axis >= 0 + assert len(input_shape) == 3 + return input_shape[: axis - 1] + input_shape[axis + 1 :], () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return reduce_fn(inputs, axis=axis) + + return init_fun, apply_fun + + +def _create_positional_encoding( # pylint: disable=invalid-name + input_shape, max_len=10000 +): + """Helper: create positional encoding parameters.""" + d_feature = input_shape[-1] + pe = np.zeros((max_len, d_feature), dtype=np.float32) + position = np.arange(0, max_len)[:, np.newaxis] + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + pe[:, 0::2] = np.sin(position * div_term) + # FIXME: Simon: the line below does not work for odd alphabet sizes! Get in touch with authors! + # pe[:, 1::2] = np.cos(position * div_term) + pe[:, 1::2] = np.cos(position * div_term[d_feature % 2 :]) + pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] + return jnp.array(pe) # These are trainable parameters, initialized as above. + + +def positional_encoding(): + """Concatenate positional encoding to the last dimension.""" + + def init_fun(rng, input_shape): + del rng + input_shape_for_enc = input_shape + params = _create_positional_encoding(input_shape_for_enc) + last_dim = input_shape[-1] + params.shape[-1] + return input_shape[:-1] + (last_dim,), (params,) + + def apply_fun(params, inputs, **kwargs): + del kwargs + assert inputs.ndim == 4 + params = params[0] + symbol_size = inputs.shape[-2] + enc = params[None, :, :symbol_size, :] + enc = jnp.repeat(enc, inputs.shape[0], 0) + return jnp.concatenate((inputs, enc), -1) + + return init_fun, apply_fun + + +def cnn( + conv_depth=300, + kernel_size=5, + n_conv_layers=2, + across_batch=False, + add_pos_encoding=False, +): + """Build convolutional neural net.""" + # Input shape: [batch x length x depth] + if across_batch: + extra_dim = 0 + else: + extra_dim = 1 + layers = [ExpandDims(axis=extra_dim)] + if add_pos_encoding: + layers.append(positional_encoding()) + + for _ in range(n_conv_layers): + layers.append( + stax.Conv(conv_depth, (1, kernel_size), padding="same", strides=(1, 1)) + ) + layers.append(stax.Relu) + layers.append(AssertNonZeroShape()) + layers.append(squeeze_layer(axis=extra_dim)) + return stax.serial(*layers) + + +def build_model_stax( + output_size, + n_dense_units=300, + conv_depth=300, + n_conv_layers=2, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + mode="train", +): + """Build a model with convolutional layers followed by dense layers.""" + del mode + layers = [ + cnn( + conv_depth=conv_depth, + n_conv_layers=n_conv_layers, + kernel_size=kernel_size, + across_batch=across_batch, + add_pos_encoding=add_pos_encoding, + ) + ] + for _ in range(n_dense_layers): + layers.append(stax.Dense(n_dense_units)) + layers.append(stax.Relu) + + layers.append(stax.Dense(output_size)) + + if mean_over_pos: + layers.append(reduce_layer(jnp.mean, axis=1)) + init_random_params, predict = stax.serial(*layers) + return init_random_params, predict + + +def sample_log_probs_top_k(log_probs, rng, temperature=1.0, k=1): + """Sample categorical distribution of log probs using gumbel max trick.""" + noise = 
jax.random.gumbel(rng, shape=log_probs.shape) + perturbed = (log_probs + noise) / temperature + samples = jnp.argsort(perturbed)[Ellipsis, -k:] + return samples + + +@jax.jit +def gather_positions(idx_to_gather, logits): + """Collect logits corresponding to the positions in the string. + + Used for collecting logits for: + 1) positions in the string (depth = 1) + 2) mutation types (depth = n_mut_types) + + Args: + idx_to_gather: [batch_size x num_mutations] Indices of the positions + in the string to gather logits for. + logits: [batch_size x str_length x depth] Logits to index. + + Returns: + Logits corresponding to the specified positions in the string: + [batch_size, num_mutations, depth] + """ + assert idx_to_gather.shape[0] == logits.shape[0] + assert idx_to_gather.ndim == 2 + assert logits.ndim == 3 + + batch_size, num_mutations = idx_to_gather.shape + batch_size, str_length, depth = logits.shape + + oh = one_hot(idx_to_gather, str_length) + assert oh.shape == (batch_size, num_mutations, str_length) + + oh = oh[Ellipsis, None] + logits = logits[:, None, :, :] + assert oh.shape == (batch_size, num_mutations, str_length, 1) + assert logits.shape == (batch_size, 1, str_length, depth) + + # Perform element-wise multiplication (with broadcasting), + # then sum over str_length dimension + result = jnp.sum(oh * logits, axis=-2) + assert result.shape == (batch_size, num_mutations, depth) + return result + + +class JaxMutationPredictor(object): + """Implements training and predicting from a Jax model. + + Attributes: + output_size: Tuple containing the sizes of components to predict + loss_fn: Loss function. + Format of the loss fn: fn(params, batch, mutations, problem, predictor) + loss_grad_fn: Gradient of the loss function + temperature: temperature parameter for Gumbel-Max sampler. + learning_rate: Learning rate for optimizer. + batch_size: Batch size of input + model_fn: Function which builds the model forward pass. Must have arguments + `vocab_size`, `max_len`, and `mode` and return Jax float arrays. + params: weights of the neural net + make_state: function to make optimizer state given the network parameters + rng: Jax random number generator + """ + + def __init__( + self, + vocab_size, + output_size, + loss_fn, + rng, + temperature=1, + learning_rate=0.001, + conv_depth=300, + n_conv_layers=2, + n_dense_units=300, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + model_fn=build_model_stax, + ): + self.output_size = output_size + self.temperature = temperature + + # Setup randomness. 
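Review note: as a shape sanity check for the model this class wraps, the sketch below calls the `stax` network returned by `build_model_stax` directly, mirroring how `__init__` and `run_model` invoke it; `split_mutation_predictor_output` is defined further down in this file, and the concrete sizes are illustrative.

```python
import numpy as np
import jax.random as jrand

vocab_size, length, batch = 4, 8, 2
init_fn, apply_fn = build_model_stax(output_size=vocab_size + 1)
_, params = init_fn(jrand.PRNGKey(0), (-1, -1, vocab_size))

x = one_hot(np.zeros((batch, length), dtype=int), vocab_size)  # [batch, length, vocab]
out = apply_fn(params, inputs=x, rng=jrand.PRNGKey(1))         # [batch, length, vocab + 1]

pos_log_probs, mut_log_probs = split_mutation_predictor_output(out)
# pos_log_probs: [batch, length]         -- where in the string to mutate
# mut_log_probs: [batch, length, vocab]  -- which token to write at each position
```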
+ self.rng = rng + + model_settings = { + "output_size": output_size, + "n_dense_units": n_dense_units, + "n_dense_layers": n_dense_layers, + "conv_depth": conv_depth, + "n_conv_layers": n_conv_layers, + "across_batch": across_batch, + "kernel_size": kernel_size, + "add_pos_encoding": add_pos_encoding, + "mean_over_pos": mean_over_pos, + "mode": "train", + } + + self._model_init, model_train = model_fn(**model_settings) + self._model_train = jax.jit(model_train) + + model_settings["mode"] = "eval" + _, model_predict = model_fn(**model_settings) + self._model_predict = jax.jit(model_predict) + + self.rng, subrng = jrand.split(self.rng) + _, init_params = self._model_init(subrng, (-1, -1, vocab_size)) + self.params = init_params + + # Setup parameters for model and optimizer + self.make_state, self._opt_update_state, self._get_params = adam(learning_rate) + + self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model) + self.loss_grad_fn = jax.grad(self.loss_fn) + + # Track steps of optimization so far. + self._step_idx = 0 + + def update_step(self, rewards, inputs, actions): + """Performs a single update step on a batch of samples. + + Args: + rewards: Batch [batch] of rewards for perturbed samples. + inputs: Batch [batch x length] of original samples + actions: actions applied on the samples + + Raises: + ValueError: if any inputs are the wrong shape. + """ + + grad_update = self.loss_grad_fn( + self.params, + rewards=rewards, + inputs=inputs, + actions=actions, + ) + + old_params = self.params + state = self.make_state(old_params) + state = self._opt_update_state(self._step_idx, grad_update, state) + + self.params = self._get_params(state) + del old_params, state + + self._step_idx += 1 + + def __call__(self, x, mode="eval"): + """Calls predict function of model. + + Args: + x: Batch of input samples. + mode: Mode for running the network: "train" or "eval" + + Returns: + A list of tuples (class weights, log likelihood) for each of + output components predicted by the model. + """ + return self.run_model(x, self.params, mode="eval") + + def run_model(self, x, params, mode="eval"): + """Run the Jax model. + + This function is used in __call__ to run the model in "eval" mode + and in the loss function to run the model in "train" mode. + + Args: + x: Batch of input samples. + params: Network parameters + mode: Mode for running the network: "train" or "eval" + + Returns: + Jax neural network output. + """ + if mode == "train": + model_fn = self._model_train + else: + model_fn = self._model_predict + self.rng, subrng = jax.random.split(self.rng) + return model_fn(params, inputs=x, rng=subrng) + + +######################################### +# Loss function +def reinforce_loss(rewards, log_likelihood): + """Loss function for Jax model. + + Args: + rewards: List of rewards [batch] for the perturbed samples. + log_likelihood: Log-likelihood of perturbations + + Returns: + Scalar loss. + """ + rewards = jax.lax.stop_gradient(rewards) + + # In general case, we assume that the loss is not differentiable + # Use REINFORCE + reinforce_estim = rewards * log_likelihood + # Take mean over the number of applied mutations, then across the batch + return -jnp.mean(jnp.mean(reinforce_estim, 1), 0) + + +def compute_entropy(log_probs): + """Compute entropy of a set of log_probs.""" + return -jnp.mean(jnp.mean(stax.softmax(log_probs) * log_probs, axis=-1)) + + +def compute_advantage(params, critic_fn, rewards, inputs): + """Compute the advantage: difference between rewards and predicted value. 
+ + Args: + params: parameters for the critic neural net + critic_fn: function to run critic neural net + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + + Returns: + advantage: [batch_size x num_mutations] + """ + assert inputs.ndim == 4 + + num_mutations, batch_size, str_length, vocab_size = inputs.shape + + inputs_reshaped = inputs.reshape( + (num_mutations * batch_size, str_length, vocab_size) + ) + + predicted_value = critic_fn(inputs_reshaped, params, mode="train") + assert predicted_value.shape == (num_mutations * batch_size, 1) + predicted_value = predicted_value.reshape((num_mutations, batch_size)) + + assert rewards.shape == (batch_size,) + rewards = jnp.repeat(rewards[None, :], num_mutations, 0) + assert rewards.shape == (num_mutations, batch_size) + + advantage = rewards - predicted_value + advantage = jnp.transpose(advantage) + assert advantage.shape == (batch_size, num_mutations) + return advantage + + +def value_loss_fn(params, run_model_fn, rewards, inputs, actions=None): + """Compute the loss for the value function. + + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: not used + + Returns: + A scalar loss. + """ + del actions + advantage = compute_advantage(params, run_model_fn, rewards, inputs) + advantage = advantage**2 + + return jnp.sqrt(jnp.mean(advantage)) + + +def split_mutation_predictor_output(output): + return stax.logsoftmax(output[:, :, -1]), stax.logsoftmax(output[:, :, :-1]) + + +def run_model_and_compute_reinforce_loss( + params, run_model_fn, rewards, inputs, actions, n_mutations, entropy_weight=0.1 +): + """Run Jax model and compute REINFORCE loss. + + Jax can compute the gradients of the model only if the model is called inside + the loss function. Here we call the Jax model, re-compute the log-likelihoods, + take log-likelihoods of the mutations and positions sampled before in + _propose function of the solver, and compute the loss. + + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: Tuple (mut_types [Batch], positions [Batch]) of mutation types + and positions sampled during the _propose() step of evolution solver. + n_mutations: Number of mutations. Used for one-hot encoding of mutations + entropy_weight: Weight on the entropy term added to the loss. + + Returns: + A scalar loss. 
+ + """ + mut_types, positions = actions + mut_types_one_hot = one_hot(mut_types, n_mutations) + + batch_size, str_length, _ = inputs.shape + assert mut_types.shape[0] == inputs.shape[0] + batch_size, num_mutations = mut_types.shape + assert mut_types.shape == positions.shape + assert mut_types.shape == rewards.shape + + output = run_model_fn(inputs, params, mode="train") + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + assert pos_log_probs.shape == (batch_size, str_length) + pos_log_probs = jnp.expand_dims(pos_log_probs, -1) + + pos_log_likelihoods = gather_positions(positions, pos_log_probs) + assert pos_log_likelihoods.shape == (batch_size, num_mutations, 1) + + # Sum over number of positions + pos_log_likelihoods = jnp.sum(pos_log_likelihoods, -1) + + # all_mut_log_probs shape: [batch_size, str_length, n_mut_types] + assert all_mut_log_probs.shape[:2] == (batch_size, str_length) + + # Get mutation logits corresponding to the chosen positions + mutation_logprobs = gather_positions(positions, all_mut_log_probs) + + # Get log probs corresponding to the selected mutations + mut_log_likelihoods_oh = mutation_logprobs * mut_types_one_hot + + # Sum over mutation types + mut_log_likelihoods = jnp.sum(mut_log_likelihoods_oh, -1) + assert mut_log_likelihoods.shape == (batch_size, num_mutations) + + joint_log_likelihood = mut_log_likelihoods + pos_log_likelihoods + assert joint_log_likelihood.shape == (batch_size, num_mutations) + + loss = reinforce_loss(rewards, joint_log_likelihood) + loss -= entropy_weight * compute_entropy(mutation_logprobs) + return loss + + +############################################ +# MutationPredictorSolver +def initialize_uniformly(domain, batch_size, random_state): + return domain.sample_uniformly(batch_size, seed=random_state) + + +# @gin.configurable +class MutationPredictorSolver(base_solver.BaseSolver): + """Choose the mutation operator conditioned on the sample. + + Sample from categorical distribution over available mutation operators + using Gumbel-Max trick + """ + + def __init__(self, domain, model_fn=build_model_stax, random_state=0, **kwargs): + """Constructs solver. + + Args: + domain: discrete domain + model_fn: Function which builds the forward pass of predictor model. + random_state: Random state to initialize jax & np RNGs. + **kwargs: kwargs passed to config. + """ + super(MutationPredictorSolver, self).__init__( + domain=domain, random_state=random_state, **kwargs + ) + self.rng = jrand.PRNGKey(random_state) + self.rng, rng = jax.random.split(self.rng) + + if self.domain.length < self.cfg.num_mutations: + # logging.warning("Number of mutations to perform per string exceeds string" + # " length. The number of mutation is set to be equal to " + # "the string length.") + self.cfg.num_mutations = self.domain.length + + # Right now the mutations are defined as "Set position X to character C". + # It allows to vectorize applying mutations to the string and speeds up + # the solver. + # If using other types of mutations, set self.use_assignment_mut=False. 
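Review note: to make the "assignment mutation" convention concrete, a small worked example using the module-level helpers defined earlier in this file (`set_pos`, `apply_mutations`); the toy arrays are illustrative.

```python
import functools
import numpy as np

samples = np.array([[0, 1, 2, 3],
                    [3, 2, 1, 0]])              # [batch x length], vocab size 4

mutation_types = np.array([[2], [0]])           # "write token 2" / "write token 0"
positions = np.array([[1], [3]])                # one mutated position per string
pos_masks = np.eye(samples.shape[1])[positions] # one-hot [batch x 1 x length]

mutations = [functools.partial(set_pos, val=v) for v in range(4)]
states = apply_mutations(samples, mutation_types, pos_masks, mutations,
                         use_assignment_mutations=True)
# states[0] is the unmutated batch; states[-1] is the result:
# [[0, 2, 2, 3],
#  [3, 2, 1, 0]]   (position 3 already held token 0)
```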
+ self.mutations = [] + for val in range(self.domain.vocab_size): + self.mutations.append(functools.partial(set_pos, val=val)) + self.use_assignment_mut = True + + mut_loss_fn = functools.partial( + run_model_and_compute_reinforce_loss, n_mutations=len(self.mutations) + ) + + # Predictor that takes the input string + # Outputs the weights over the 1) mutations types 2) position in string + if self.cfg.pretrained_model is None: + self._mut_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=len(self.mutations) + 1, + loss_fn=mut_loss_fn, + rng=rng, + model_fn=build_model_stax, + conv_depth=self.cfg.actor_conv_depth, + n_conv_layers=self.cfg.actor_n_conv_layers, + n_dense_units=self.cfg.actor_n_dense_units, + n_dense_layers=self.cfg.actor_n_dense_layers, + across_batch=self.cfg.actor_across_batch, + add_pos_encoding=self.cfg.actor_add_pos_encoding, + kernel_size=self.cfg.actor_kernel_size, + learning_rate=self.cfg.actor_learning_rate, + ) + + if self.cfg.use_actor_critic: + self._value_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=1, + rng=rng, + loss_fn=value_loss_fn, + model_fn=build_model_stax, + mean_over_pos=True, + conv_depth=self.cfg.critic_conv_depth, + n_conv_layers=self.cfg.critic_n_conv_layers, + n_dense_units=self.cfg.critic_n_dense_units, + n_dense_layers=self.cfg.critic_n_dense_layers, + across_batch=self.cfg.critic_across_batch, + add_pos_encoding=self.cfg.critic_add_pos_encoding, + kernel_size=self.cfg.critic_kernel_size, + learning_rate=self.cfg.critic_learning_rate, + ) + + else: + self._value_predictor = None + else: + self._mut_predictor, self._value_predictor = self.cfg.pretrained_model + + self._data_for_grad_update = [] + self._initialized = False + + def _config(self): + cfg = super(MutationPredictorSolver, self)._config() + + cfg.update( + dict( + predictor=JaxMutationPredictor, + temperature=1.0, + initialize_dataset_fn=initialize_uniformly, + elite_set_size=10, + use_random_network=False, + exploit_with_best=True, + use_selection_of_best=False, + pretrained_model=None, + # Indicator to BO to pass in previous weights. + # As implemented in cl/318101597. 
+ warmstart=True, + use_actor_critic=False, + num_mutations=5, + # Hyperparameters for actor + actor_learning_rate=0.001, + actor_conv_depth=300, + actor_n_conv_layers=1, + actor_n_dense_units=100, + actor_n_dense_layers=0, + actor_kernel_size=5, + actor_across_batch=False, + actor_add_pos_encoding=True, + # Hyperparameters for critic + critic_learning_rate=0.001, + critic_conv_depth=300, + critic_n_conv_layers=1, + critic_n_dense_units=300, + critic_n_dense_layers=0, + critic_kernel_size=5, + critic_across_batch=False, + critic_add_pos_encoding=True, + ) + ) + return cfg + + def _get_unique(self, samples): + unique_population = data.Population() + unique_structures = set() + for sample in samples: + hashed_structure = utils.hash_structure(sample.structure) + if hashed_structure in unique_structures: + continue + unique_structures.add(hashed_structure) + unique_population.add_samples([sample]) + return unique_population + + def _get_best_samples_from_last_batch( + self, population, n=1, discard_duplicates=True + ): + best_samples = population.get_last_batch().best_n( + n, discard_duplicates=discard_duplicates + ) + return best_samples.structures, best_samples.rewards + + def _select(self, population): + if self.cfg.use_selection_of_best: + # Choose best samples from the previous batch + structures, rewards = self._get_best_samples_from_last_batch( + population, self.cfg.elite_set_size + ) + + # Choose the samples to perturb with replacement + idx = np.random.choice( + len(structures), self.batch_size, replace=True + ) # pytype: disable=attribute-error # trace-all-classes + selected_structures = np.stack([structures[i] for i in idx]) + selected_rewards = np.stack([rewards[i] for i in idx]) + return selected_structures, selected_rewards + else: + # Just return the samples from the previous batch -- no selection + last_batch = population.get_last_batch() + structures = np.array([x.structure for x in last_batch]) + rewards = np.array([x.reward for x in last_batch]) + + if ( + len(last_batch) > self.batch_size + ): # pytype: disable=attribute-error # trace-all-classes + # Subsample the data + idx = np.random.choice( + len(last_batch), self.batch_size, replace=False + ) # pytype: disable=attribute-error # trace-all-classes + structures = np.stack([structures[i] for i in idx]) + rewards = np.stack([rewards[i] for i in idx]) + return structures, rewards + + def propose( + self, num_samples, population=None, pending_samples=None + ): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + # Initialize population randomly. 
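Review note: before the body of `propose`, a sketch of the loop a caller is expected to drive: the first call (empty population) returns samples from `initialize_dataset_fn`, later calls mutate the previous batch and, once rewards are in the population, trigger the REINFORCE update. The toy reward and the use of `FixedLengthDiscreteDomain.sample_uniformly` via the default initializer are assumptions based on this file and `domains.py`.

```python
import numpy as np

from poli_baselines.solvers.bayesian_optimization.amortized import data
from poli_baselines.solvers.bayesian_optimization.amortized.deep_evolution_solver import (
    MutationPredictorSolver,
)
from poli_baselines.solvers.bayesian_optimization.amortized.domains import (
    FixedLengthDiscreteDomain,
    Vocabulary,
)

domain = FixedLengthDiscreteDomain(vocab=Vocabulary(4), length=8)
solver = MutationPredictorSolver(domain=domain, random_state=0)
population = data.Population()

def toy_reward(batch):
    # Illustrative objective: count occurrences of token index 2 in each string.
    return np.array([(np.asarray(s) == 2).sum() for s in batch], dtype=float)

for _ in range(5):
    batch = solver.propose(num_samples=16, population=population)
    population.add_batch(batch, toy_reward(batch))

print(population.best().reward)
```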
+ if self._initialized and population: + if num_samples != self.batch_size: + raise ValueError("Must maintain constant batch size between runs.") + counter = population.max_batch_index + if counter > 0: + if not self.cfg.use_random_network: + self._update_params(population) + else: + self.batch_size = num_samples + self._initialized = True + return self.cfg.initialize_dataset_fn( + self.domain, num_samples, random_state=self._random_state + ) + + # Choose best samples so far -- [elite_set_size] + samples_to_perturb, parent_rewards = self._select(population) + + perturbed, actions, mut_predictor_input = self._perturb(samples_to_perturb) + + if not self.cfg.use_random_network: + self._data_for_grad_update.append( + { + "batch_index": population.current_batch_index + 1, + "mut_predictor_input": mut_predictor_input, + "actions": actions, + "parent_rewards": parent_rewards, + } + ) + + return np.asarray(perturbed) + + def _perturb(self, parents, mode="train"): + length = parents.shape[1] + assert length == self.domain.length + + parents_one_hot = one_hot(parents, self.domain.vocab_size) + + output = self._mut_predictor(parents_one_hot) + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + + self.rng, subrng = jax.random.split(self.rng) + positions = sample_log_probs_top_k( + pos_log_probs, + subrng, + k=self.cfg.num_mutations, + temperature=self.cfg.temperature, + ) + + pos_masks = one_hot(positions, length) + + mutation_logprobs = gather_positions(positions, all_mut_log_probs) + assert mutation_logprobs.shape == ( + output.shape[0], + self.cfg.num_mutations, + output.shape[-1] - 1, + ) + + self.rng, subrng = jax.random.split(self.rng) + mutation_types = gumbel_max_sampler( + mutation_logprobs, self.cfg.temperature, subrng + ) + + states = apply_mutations( + parents, + mutation_types, + pos_masks, + self.mutations, + use_assignment_mutations=self.use_assignment_mut, + ) + # states shape: [num_mutations+1, batch, str_length] + # states[0] are original samples with no mutations + # states[-1] are strings with all mutations applied to them + states_oh = one_hot(states, self.domain.vocab_size) + # states_oh shape: [n_mutations+1, batch, str_length, vocab_size] + perturbed = states[-1] + + return perturbed, (mutation_types, positions), states_oh + + def _update_params(self, population): + if not self._data_for_grad_update: + return + + dat = self._data_for_grad_update.pop() + assert dat["batch_index"] == population.current_batch_index + + child_rewards = jnp.array(population.get_last_batch().rewards) + parent_rewards = dat["parent_rewards"] + + all_states = dat["mut_predictor_input"] + # all_states shape: [num_mutations, batch_size, str_length, vocab_size] + + # TODO(rubanova): rescale the rewards + terminal_rewards = child_rewards + + if self.cfg.use_actor_critic: + # Update the value function + # Compute the difference between predicted value of intermediate states + # and the final reward. + self._value_predictor.update_step( + rewards=terminal_rewards, + inputs=all_states[:-1], + actions=None, + ) + advantage = compute_advantage( + self._value_predictor.params, + self._value_predictor.run_model, + terminal_rewards, + all_states[:-1], + ) + else: + advantage = child_rewards - parent_rewards + advantage = jnp.repeat(advantage[:, None], self.cfg.num_mutations, 1) + + advantage = jax.lax.stop_gradient(advantage) + + # Perform policy update. + # Compute policy on the original samples, like in _perturb function. 
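+        # REINFORCE-style actor update: the advantage computed above is passed
+        # as the per-mutation reward and the sampled (mutation type, position)
+        # pairs as the actions, so the loss defined by
+        # `run_model_and_compute_reinforce_loss` reweights their
+        # log-probabilities accordingly.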
+        self._mut_predictor.update_step(
+            rewards=advantage, inputs=all_states[0], actions=dat["actions"]
+        )
+
+        del all_states, advantage
+
+    @property
+    def trained_model(self):
+        return (self._mut_predictor, self._value_predictor)
diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py
new file mode 100644
index 0000000..30c9599
--- /dev/null
+++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright 2024 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Specifications for different types of input/output domains."""
+
+import abc
+import collections
+from collections.abc import Iterable
+
+import gin
+import numpy as np
+import six
+from six.moves import range
+
+from poli_baselines.solvers.bayesian_optimization.amortized import utils
+
+BOS_TOKEN = "<"  # Beginning of sequence token.
+EOS_TOKEN = ">"  # End of sequence token.
+PAD_TOKEN = "_"  # Padding token.
+MASK_TOKEN = "*"  # Mask token.
+SEP_TOKEN = "|"  # A special token for separating tokens for serialization.
+
+
+@gin.configurable
+class Vocabulary(object):
+    """Basic vocabulary used to represent output tokens for domains."""
+
+    def __init__(
+        self,
+        tokens,
+        include_bos=False,
+        include_eos=False,
+        include_pad=False,
+        include_mask=False,
+        bos_token=BOS_TOKEN,
+        eos_token=EOS_TOKEN,
+        pad_token=PAD_TOKEN,
+        mask_token=MASK_TOKEN,
+    ):
+        """A token vocabulary.
+
+        Args:
+          tokens: A list of tokens to put in the vocab. If an int, will be
+            interpreted as the number of tokens and '0', ..., 'tokens-1' will be
+            used as tokens.
+          include_bos: Whether to append `bos_token` to `tokens`, which marks the
+            beginning of a sequence.
+          include_eos: Whether to append `eos_token` to `tokens`, which marks the
+            end of a sequence.
+          include_pad: Whether to append `pad_token` to `tokens`, which marks
+            positions past the end of a sequence.
+          include_mask: Whether to append `mask_token` to `tokens`, which marks
+            masked positions.
+          bos_token: A special token that marks the beginning of a sequence.
+            Ignored if `include_bos == False`.
+          eos_token: A special token that marks the end of a sequence.
+            Ignored if `include_eos == False`.
+          pad_token: A special token that marks positions past the end of a
+            sequence. Ignored if `include_pad == False`.
+          mask_token: A special token that marks MASKED positions, e.g., for
+            BERT. Ignored if `include_mask == False`.
+        """
+        if not isinstance(tokens, Iterable):
+            tokens = range(tokens)
+        tokens = [str(token) for token in tokens]
+        if include_bos:
+            tokens.append(bos_token)
+        if include_eos:
+            tokens.append(eos_token)
+        if include_pad:
+            tokens.append(pad_token)
+        if include_mask:
+            tokens.append(mask_token)
+        if len(set(tokens)) != len(tokens):
+            raise ValueError("tokens not unique!")
+        special_tokens = sorted(set(tokens) & set([SEP_TOKEN]))
+        if special_tokens:
+            raise ValueError(
+                f"tokens contains reserved special tokens: {special_tokens}!"
+ ) + + self._tokens = tokens + self._token_ids = list(range(len(self._tokens))) + self._id_to_token = collections.OrderedDict(zip(self._token_ids, self._tokens)) + self._token_to_id = collections.OrderedDict(zip(self._tokens, self._token_ids)) + self._bos_token = bos_token if include_bos else None + self._eos_token = eos_token if include_eos else None + self._mask_token = mask_token if include_mask else None + self._pad_token = pad_token if include_pad else None + + def __len__(self): + return len(self._tokens) + + @property + def tokens(self): + """Return the tokens of the vocabulary.""" + return list(self._tokens) + + @property + def token_ids(self): + """Return the tokens ids of the vocabulary.""" + return list(self._token_ids) + + @property + def bos(self): + """Returns the index of the BOS token or None if unspecified.""" + return None if self._bos_token is None else self._token_to_id[self._bos_token] + + @property + def eos(self): + """Returns the index of the EOS token or None if unspecified.""" + return None if self._eos_token is None else self._token_to_id[self._eos_token] + + @property + def mask(self): + """Returns the index of the MASK token or None if unspecified.""" + return None if self._mask_token is None else self._token_to_id[self._mask_token] + + @property + def pad(self): + """Returns the index of the PAD token or None if unspecified.""" + return None if self._pad_token is None else self._token_to_id[self._pad_token] + + def is_valid(self, value): + """Tests if a value is a valid token id and returns a bool.""" + return value in self._token_ids + + def are_valid(self, values): + """Tests if values are valid token ids and returns an array of bools.""" + return np.array([self.is_valid(value) for value in values]) + + def encode(self, tokens): + """Maps an iterable of string tokens to a list of integer token ids.""" + if six.PY3 and isinstance(tokens, bytes): + # Always use Unicode in Python 3. + tokens = tokens.decode("utf-8") + return [self._token_to_id[token] for token in tokens] + + def decode(self, values, stop_at_eos=False, as_str=True): + """Maps an iterable of integer token ids to string tokens. + + Args: + values: An iterable of token ids. + stop_at_eos: Whether to ignore all values after the first EOS token id. + as_str: Whether to return a list of tokens or a concatenated string. + + Returns: + A string of tokens or a list of tokens if `as_str == False`. + """ + if stop_at_eos and self.eos is None: + raise ValueError("EOS unspecified!") + tokens = [] + for value in values: + value = int(value) # Requires if value is a scalar tensor. 
+ if stop_at_eos and value == self.eos: + break + tokens.append(self._id_to_token[value]) + return "".join(tokens) if as_str else tokens + + +@six.add_metaclass(abc.ABCMeta) +class Domain(object): + """Base class of problem domains, which specifies the set of valid objects.""" + + @property + def mask_fn(self): + """Returns a masking function or None.""" + + @abc.abstractmethod + def is_valid(self, sample): + """Tests if the given sample is valid for this domain.""" + + def are_valid(self, samples): + """Tests if the given samples are valid for this domain.""" + return np.array([self.is_valid(sample) for sample in samples]) + + +class DiscreteDomain(Domain): + """Base class for discrete domains: sequences of categorical variables.""" + + def __init__(self, vocab): + self._vocab = vocab + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._vocab # pytype: disable=attribute-error # trace-all-classes + + def encode(self, samples, **kwargs): + """Maps a list of string tokens to a list of lists of integer token ids.""" + return [self.vocab.encode(sample, **kwargs) for sample in samples] + + def decode(self, samples, **kwargs): + """Maps list of lists of integer token ids to list of strings.""" + return [self.vocab.decode(sample, **kwargs) for sample in samples] + + +@gin.configurable +class FixedLengthDiscreteDomain(DiscreteDomain): + """Output is a fixed length discrete sequence.""" + + def __init__(self, vocab_size=None, length=None, vocab=None): + """Creates an instance of this class. + + Args: + vocab_size: An optional integer for constructing a vocab of this size. + If provided, `vocab` must be `None`. + length: The length of the domain (required). + vocab: The `Vocabulary` of the domain. If provided, `vocab_size` must be + `None`. + + Raises: + ValueError: If neither `vocab_size` nor `vocab` is provided. + ValueError: If `length` if not provided. 
+ """ + if length is None: + raise ValueError("length must be provided!") + if not (vocab_size is None) ^ (vocab is None): + raise ValueError("Exactly one of vocab_size of vocab must be specified!") + self._length = length + if vocab is None: + vocab = Vocabulary(vocab_size) + super(FixedLengthDiscreteDomain, self).__init__(vocab) + + @property + def length(self): + return self._length + + @property + def size(self): + """The number of structures in the Domain.""" + return self.vocab_size**self.length + + def is_valid(self, sequence): + return len(sequence) == self.length and self.vocab.are_valid(sequence).all() + + def sample_uniformly(self, num_samples, seed=None): + random_state = utils.get_random_state(seed) + return np.int32( + random_state.randint( + size=[num_samples, self.length], low=0, high=self.vocab_size + ) + ) + + def index_to_structure(self, index): + """Given an integer and target length, encode into structure.""" + structure = np.zeros(self.length, dtype=np.int32) + tokens = [ + int(token, base=len(self.vocab)) + for token in np.base_repr(index, base=len(self.vocab)) + ] + structure[-len(tokens) :] = tokens + return structure + + def structure_to_index(self, structure): + """Returns the index of a sequence over a vocabulary of size `vocab_size`.""" + structure = np.asarray(structure)[::-1] + return np.sum(structure * np.power(len(self.vocab), range(len(structure)))) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt b/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt new file mode 100644 index 0000000..d8b5764 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt @@ -0,0 +1,6 @@ +tensorflow +gin-config >= 0.3.0 +pandas >= 1.0.5 +attrs >= 19.3.0 +jax >= 0.1.71 +jaxlib >= 0.1.48 \ No newline at end of file diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py new file mode 100644 index 0000000..af76f4a --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py @@ -0,0 +1,381 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions.""" + +import collections +import copy +import functools +import inspect +import logging as std_logging +import pprint +import random +import types +import uuid + +from absl import logging +import numpy as np +import pandas as pd +import tensorflow.compat.v1 as tf + +MIN_INT = np.iinfo(np.int64).min +MAX_INT = np.iinfo(np.int64).max +MIN_FLOAT = np.finfo(np.float32).min +MAX_FLOAT = np.finfo(np.float32).max + + +# TODO(ddohan): FrozenConfig type +class Config(dict): + """a dictionary that supports dot and dict notation. 
+ + Create: + d = Config() + d = Config({'val1':'first'}) + + Get: + d.val2 + d['val2'] + + Set: + d.val2 = 'second' + d['val2'] = 'second' + """ + + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __str__(self): + return pprint.pformat(self) + + def __deepcopy__(self, memo): + return self.__class__( + [(copy.deepcopy(k, memo), copy.deepcopy(v, memo)) for k, v in self.items()] + ) + + +def get_logger(name="", with_absl=True, level=logging.INFO): + """Creates a logger.""" + logger = std_logging.getLogger(name) + if with_absl: + logger.addHandler(logging.get_absl_handler()) + logger.propagate = False + logger.setLevel(level) + return logger + + +def get_random_state(seed_or_state): + """Returns a np.random.RandomState given an integer seed or RandomState.""" + if isinstance(seed_or_state, int): + return np.random.RandomState(seed_or_state) + elif seed_or_state is None: + # This returns the current global np random state. + return np.random.random.__self__ + elif not isinstance(seed_or_state, np.random.RandomState): + raise ValueError( + "Numpy RandomState or integer seed expected! Got: %s" % seed_or_state + ) + else: + return seed_or_state + + +def set_seed(seed): + """Sets global Numpy, Tensorboard, and random seed.""" + np.random.seed(seed) + tf.set_random_seed(seed) + random.seed(seed, version=1) + + +def to_list(values, none_to_list=True): + """Converts `values` of any type to a `list`.""" + if ( + hasattr(values, "__iter__") + and not isinstance(values, str) + and not isinstance(values, dict) + ): + return list(values) + elif none_to_list and values is None: + return [] + else: + return [values] + + +def to_array(values): + """Converts input values to a np.ndarray.""" + if tf.executing_eagerly() and tf.is_tensor(values): + return values.numpy() + else: + return np.asarray(values) + + +def arrays_from_dataset(dataset): + """Converts a tf.data.Dataset to nested np.ndarrays.""" + return tf.nest.map_structure( + lambda tensor: np.asarray(tensor), # pylint: disable=unnecessary-lambda + tensors_from_dataset(dataset), + ) + + +def dataset_from_tensors(tensors): + """Converts nested tf.Tensors or np.ndarrays to a tf.Data.Dataset.""" + if isinstance(tensors, types.GeneratorType) or isinstance(tensors, list): + tensors = tuple(tensors) + return tf.data.Dataset.from_tensor_slices(tensors) + + +def random_choice(values, size=None, random_state=None, **kwargs): + """Enables safer sampling from a list of values than `np.random.choice`. + + `np.random.choice` fails when trying to sample, e.g., from a list of + + indices instead of sampling from `values` directly. + + Args: + values: An iterable of values. + size: The sample size. + random_state: An integer seed or a `np.random.RandomState`. + **kwargs: Named arguments passed to `np.random.choice`. + + Returns: + As single element from `values` if `size is None`, otherwise a list of + samples from `values` of length `size`. 
+ """ + random_state = get_random_state(random_state) + values = list(values) + effective_size = 1 if size is None else size + idxs = random_state.choice(range(len(values)), size=effective_size, **kwargs) + samples = [values[idx] for idx in idxs] + return samples[0] if size is None else samples + + +def random_shuffle(values, random_state=None): + """Shuffles a list of `values` out-of-place.""" + return random_choice( + values, size=len(values), replace=False, random_state=random_state + ) + + +def get_tokens(sequences, lower=False, upper=False): + """Returns a sorted list of all unique characters of a list of sequences. + + Args: + sequences: An iterable of string sequences. + lower: Whether to lower-case sequences before computing tokens. + upper: Whether to upper-case sequences before computing tokens. + + Returns: + A sorted list of all characters that appear in `sequences`. + """ + if lower and upper: + raise ValueError("lower and upper must not be specified at the same time!") + if lower: + sequences = [seq.lower() for seq in sequences] + if upper: + sequences = [seq.upper() for seq in sequences] + return sorted(set.union(*[set(seq) for seq in sequences])) + + +def tensors_from_dataset(dataset): + """Converts a tf.data.Dataset to nested tf.Tensors.""" + tensors = list(dataset) + if tensors: + return tf.nest.map_structure(lambda *tensors: tf.stack(tensors), *tensors) + # Return empty tensors if the dataset is empty. + shapes_dtypes = zip( + tf.nest.flatten(dataset.output_shapes), tf.nest.flatten(dataset.output_types) + ) + tensors = [ + tf.zeros(shape=[0] + shape.as_list(), dtype=dtype) + for shape, dtype in shapes_dtypes + ] + return tf.nest.pack_sequence_as(dataset.output_shapes, tensors) + + +def hash_structure(structure): + """Hashes a structure (n-d numpy array) of either ints or floats to a string. + + Args: + structure: A structure of ints that is castable to np.int32 (examples + include an np.int32, np.int64, an int32 eager tf.Tensor, + a list of python ints, etc.) or a structure of floats that is castable + to np.float32. Here, we say that an array is castable if it can be + converted to the target type, perhaps with some loss of precision + (e.g. float64 to float32). See np.can_cast(..., casting='same_kind'). + + Returns: + A string hash for the structure. The hash will depend on the + high-level type (int vs. float), but not the precision of such a type + (int32 vs. int64). + """ + array = np.asarray(structure) + if np.can_cast(array, np.int32, "same_kind"): + return np.int32(array).tostring() + elif np.can_cast(array, np.float32, "same_kind"): + return np.float32(array).tostring() + raise ValueError( + "%s can not be safely cast to np.int32 or " "np.float32" % str(structure) + ) + + +def create_unique_id(): + """Creates a unique hex ID.""" + return uuid.uuid1().hex + + +def deduplicate_samples(samples, select_best=False): + """De-duplicates Samples with identical structures. + + Args: + samples: An iterable of `data.Sample`s. + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. + + Returns: + A list of Samples. 
+ """ + + def _sort(to_sort): + return ( + sorted(to_sort, key=lambda sample: sample.reward, reverse=True) + if select_best + else to_sort + ) + + return [_sort(group)[0] for group in group_samples(samples).values()] + + +def group_samples(samples, **kwargs): + """Groups `data.Sample`s with identical structures using `group_by_hash`.""" + return group_by_hash( + samples, hash_fn=lambda sample: hash_structure(sample.structure), **kwargs + ) + + +def group_by_hash(values, hash_fn=hash, store_index=False): + """Groups values by their hash value. + + Args: + values: An iterable of any objects. + hash_fn: A function that is called to compute the hash of values. + store_index: Whether to store the index of values or values in the returned + dict. + + Returns: + A `collections.OrderedDict` mapping hashes to values with that hash if + `store_index=False`, or value indices otherwise. The length of the map + corresponds to the number of unique hashes. + """ + groups = collections.OrderedDict() + for idx, value in enumerate(values): + to_store = idx if store_index else value + groups.setdefault(hash_fn(value), []).append(to_store) + return groups + + +def get_instance(instance_or_cls, **kwargs): + """Returns an instance given an instance or class reference. + + Enables passing both class references and class instances as (gin) configs. + + Args: + instance_or_cls: An instance of class or reference to a class. + **kwargs: Names arguments used for instantiation if `instance_or_cls` is a + class. + + Returns: + An instance of a class. + """ + if ( + inspect.isclass(instance_or_cls) + or inspect.isfunction(instance_or_cls) + or isinstance(instance_or_cls, functools.partial) + ): + return instance_or_cls(**kwargs) + else: + return instance_or_cls + + +def pd_option_context( + width=999, max_colwidth=999, max_rows=200, float_format="{:.3g}", **kwargs +): + """Returns a Pandas context manager with changed default arguments.""" + return pd.option_context( + "display.width", + width, + "display.max_colwidth", + max_colwidth, + "display.max_rows", + max_rows, + "display.float_format", + float_format.format, + **kwargs, + ) + + +def log_pandas(df_or_series, logger=logging.info, **kwargs): + """Logs a `pd.DataFrame` or `pd.Series`.""" + with pd_option_context(**kwargs): + for row in df_or_series.to_string().splitlines(): + logger(row) + + +def get_indices( + valid_indices, selection, map_negative=True, validate=False, exclude=False +): + """Maps a `selection` to `valid_indices` for indexing an iterable. + + Supports selecting indices as by + - a scalar: it[i] + - a list of scalars: it[[i, j, k]] + - a (list of) negative scalars: it[[i, -j, -k]] + - slices: it[slice(-3, None)] + + Args: + valid_indices: An iterable of valid indices that can be selected. + selection: A scalar, list of scalars, or `slice` for selecting indices. + map_negative: Whether to interpret `-i` as `len(valid_indices) + i`. + validate: Whether to raise an `IndexError` if `selection` is not + contained in `valid_indices`. + exclude: Whether to return all `valid_indices` except the selected ones. + + Raises: + IndexError: If `validate == True` and `selection` contains an index that is + not contained in `valid_indices`. + + Returns: + A list of indices. + """ + + def _raise_index_error(idx): + raise IndexError(f"Index {idx} invalid! 
Valid indices: {valid_indices}")
+
+    if isinstance(selection, slice):
+        idxs = valid_indices[selection]
+    else:
+        idxs = []
+        for idx in to_list(selection):
+            if map_negative and isinstance(idx, int) and idx < 0:
+                if abs(idx) <= len(valid_indices):
+                    idxs.append(valid_indices[idx])
+                elif validate:
+                    _raise_index_error(idx)
+            elif idx in valid_indices:
+                idxs.append(idx)
+            elif validate:
+                _raise_index_error(idx)
+    if exclude:
+        idxs = [idx for idx in valid_indices if idx not in idxs]
+    return idxs
diff --git a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
index 81c7fac..6855f75 100644
--- a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
+++ b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
@@ -51,6 +51,12 @@ def __init__(
             The first row contains the lower bounds on x, the last row contains the upper bounds.
         """
         super().__init__(black_box, x0, y0)
+
+        if bounds is not None:
+            assert bounds.shape[1] == 2
+            assert bounds.shape[0] == x0.shape[1]
+            assert np.all(bounds[:, 1] >= bounds[:, 0])
+
         assert x0.shape[0] > 1

         if bounds is None:
diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py
new file mode 100644
index 0000000..ca0defd
--- /dev/null
+++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py
@@ -0,0 +1,26 @@
+"""This module tests the amortized BO solver."""
+
+import warnings
+import numpy as np
+
+from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import (
+    AmortizedBOWrapper,
+)
+
+warnings.filterwarnings("ignore")
+
+
+def test_amortized_bo_runs():
+    from poli import objective_factory
+
+    problem = objective_factory.create(name="aloha", observer_name=None)
+    black_box, x0 = problem.black_box, problem.x0
+    y0 = black_box(x0)
+
+    solver = AmortizedBOWrapper(black_box, x0, y0)
+
+    solver.solve(max_iter=5)
+
+
+if __name__ == "__main__":
+    test_amortized_bo_runs()
diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py
index 2fa7eb6..43db1c5 100644
--- a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py
+++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py
@@ -18,12 +18,14 @@ def test_turbo_runs():
         n_dimensions=10,
     )
     black_box, x0 = problem.black_box, problem.x0
+    x0 = np.concatenate([x0, np.random.rand(1, x0.shape[1])])
     y0 = black_box(x0)

-    x0 = np.random.uniform(0, 1, size=20).reshape(2, 10)
-    y0 = black_box(x0)
+    bounds = np.concatenate(
+        [-np.ones([x0.shape[1], 1]), np.ones([x0.shape[1], 1])], axis=-1
+    )

-    solver = Turbo(black_box, x0, y0)
+    solver = Turbo(black_box, x0, y0, bounds=bounds)

     solver.solve(max_iter=5)
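For reference, a minimal usage sketch of the pieces added above. It assumes the `aloha` toy problem is available in `poli` (as in the new test); the five-token vocabulary is purely illustrative and is not the alphabet the black box itself uses.

from poli import objective_factory

from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import (
    AmortizedBOWrapper,
)
from poli_baselines.solvers.bayesian_optimization.amortized.domains import (
    FixedLengthDiscreteDomain,
    Vocabulary,
)

# Stand-alone encode/decode round-trip over an illustrative 5-token vocabulary.
domain = FixedLengthDiscreteDomain(
    vocab=Vocabulary(["A", "L", "O", "H", "E"]), length=5
)
encoded = domain.encode(["ALOHA"])  # -> [[0, 1, 2, 3, 0]]
decoded = domain.decode(encoded)  # -> ["ALOHA"]

# One manual ask/evaluate step with the wrapper, mirroring the test above.
problem = objective_factory.create(name="aloha", observer_name=None)
black_box, x0 = problem.black_box, problem.x0
y0 = black_box(x0)

solver = AmortizedBOWrapper(black_box, x0, y0)
candidate = solver.next_candidate()  # token array, shape [1, sequence_length]
print(candidate, black_box(candidate))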