From e9d4b9e625742c7e2183f20fe11af50d476c54fb Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Mon, 27 May 2024 17:06:05 +0200 Subject: [PATCH 1/9] adds amortized bo wrapper --- .../amortized/__init__.py | 0 .../amortized/amortized_bo_wrapper.py | 19 + .../amortized/base_solver.py | 89 ++ .../bayesian_optimization/amortized/data.py | 685 ++++++++++++ .../amortized/deep_evolution_solver.py | 978 ++++++++++++++++++ .../amortized/domains.py | 272 +++++ .../amortized/requirements.txt | 0 .../bayesian_optimization/amortized/utils.py | 364 +++++++ .../test_amortized_bo.py | 32 + 9 files changed, 2439 insertions(+) create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/__init__.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/data.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt create mode 100644 src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py create mode 100644 src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/__init__.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py new file mode 100644 index 0000000..646688c --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py @@ -0,0 +1,19 @@ +import numpy as np +from poli.core.abstract_black_box import AbstractBlackBox + +from poli_baselines.core.abstract_solver import AbstractSolver +from poli_baselines.solvers.bayesian_optimization.amortized.data import samples_from_arrays, Population +from poli_baselines.solvers.bayesian_optimization.amortized.deep_evolution_solver import MutationPredictorSolver +from poli_baselines.solvers.bayesian_optimization.amortized.domains import FixedLengthDiscreteDomain, Vocabulary + + +class AmortizedBOWrapper(AbstractSolver): + def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray): + super().__init__(black_box, x0, y0) + self.domain = FixedLengthDiscreteDomain(vocab=Vocabulary(black_box.get_black_box_info().get_alphabet()), length=x0.shape[1]) + self.solver = MutationPredictorSolver(domain=self.domain, initialize_dataset_fn=lambda *args, **kwargs: self.domain.encode(x0)) + + def next_candidate(self) -> np.ndarray: + samples = samples_from_arrays(structures=self.domain.encode(self.x0.tolist()), rewards=self.y0.tolist()) + x = self.solver.propose(num_samples=1, population=Population(samples)) + return np.array([c for c in self.domain.decode(x)[0]])[np.newaxis, :] diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py new file mode 100644 index 0000000..f065b35 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2024 
The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Solver base class.""" + +import abc + +from absl import logging + +from poli_baselines.solvers.bayesian_optimization.amortized import utils + + +class BaseSolver(abc.ABC): + """Solver base class.""" + + def __init__(self, + domain, + random_state=None, + name=None, + log_level=logging.INFO, + **kwargs): + """Creates an instance of this class. + + Args: + domain: An instance of a `Domain`. + random_state: An instance of or integer seed to build a + `np.random.RandomState`. + name: The name of the solver. If `None`, will use the class name. + log_level: The logging level of the solver-specific logger. + -2=ERROR, -1=WARN, 0=INFO, 1=DEBUG. + **kwargs: Named arguments stored in `self.cfg`. + """ + self._domain = domain + self._name = name or self.__class__.__name__ + self._random_state = utils.get_random_state(random_state) + self._log = utils.get_logger(self._name, level=log_level) + + cfg = utils.Config(self._config()) + cfg.update(kwargs) + self.cfg = cfg + + def _config(self): + return {} + + @property + def domain(self): + """Return the optimization domain.""" + return self._domain + + @property + def name(self): + """Returns the solver name.""" + return self._name + + def __str__(self): + return self._name + + @abc.abstractmethod + def propose(self, + num_samples, + population=None, + pending_samples=None, + counter=0): + """Proposes num_samples from `self.domain`. + + Args: + num_samples: The number of samples to return. + population: A `Population` of samples or None if the population is empty. + pending_samples: A list of structures without reward that were already + proposed. + counter: The number of times `propose` has been called with the same + `population`. Can be used by solver to avoid repeated computations on + the same `population`, e.g. updating a model. + + Returns: + `num_samples` structures from the domain. + """ \ No newline at end of file diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py new file mode 100644 index 0000000..f98d72f --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py @@ -0,0 +1,685 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
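For orientation, here is a minimal sketch of how the `AmortizedBOWrapper` defined above might be driven against a poli black box. The loop and the names `f`, `x0`, and `y0` are assumptions made for illustration only and are not part of this patch; the wrapper itself only requires an `AbstractBlackBox` plus initial token sequences and their observed values.

    import numpy as np
    from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import (
        AmortizedBOWrapper,
    )

    # Assumed inputs (not defined in this patch):
    #   f  -- any poli AbstractBlackBox over fixed-length token sequences
    #   x0 -- initial sequences as a token array of shape [n, L]
    #   y0 -- their observed values, shape [n, 1]
    solver = AmortizedBOWrapper(black_box=f, x0=x0, y0=y0)

    x_next = solver.next_candidate()  # decoded tokens, shape [1, L]
    y_next = f(x_next)                # poli black boxes are callable on [b, L] arrays

Each call to `next_candidate()` re-encodes `x0`/`y0` into a `Population` of `Sample`s (see data.py below) and asks `MutationPredictorSolver.propose()` for a single mutated sequence, which is then decoded back into the black box's alphabet.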
+ +"""Population class for keeping track of structures and rewards.""" + +import collections +import itertools +import typing + +import attr +import numpy as np +import pandas as pd +import tensorflow.compat.v1 as tf + +from poli_baselines.solvers.bayesian_optimization.amortized import domains +from poli_baselines.solvers.bayesian_optimization.amortized import utils + +# DatasetSample defines the structure of samples in tf.data.Datasets for +# pre-training solvers, whereas Samples (see below) defines the structure of +# samples in a Population object. +DatasetSample = collections.namedtuple('DatasetSample', ['structure', 'reward']) + + +def dataset_to_population(population_or_tf_dataset): + """Converts a TF dataset to a Population if it is not already a Population.""" + if isinstance(population_or_tf_dataset, Population): + return population_or_tf_dataset + else: + return Population.from_dataset(population_or_tf_dataset) + + +def serialize_structure(structure): + """Converts a structure to a string.""" + structure = np.asarray(structure) + dim = len(structure.shape) + if dim != 1: + raise NotImplementedError(f'`structure` must be 1d but is {dim}d!') + return domains.SEP_TOKEN.join(str(token) for token in structure) + + +def serialize_structures(structures, **kwargs): + """Converts a list of structures to a list of strings.""" + return [serialize_structure(structure, **kwargs) + for structure in structures] + + +def deserialize_structure(serialized_structure, dtype=np.int32): + """Converts a string to a structure. + + Args: + serialized_structure: A structure produced by `serialize_structure`. + dtype: The data type of the output numpy array. + + Returns: + A numpy array with `dtype`. + """ + return np.asarray( + [token for token in serialized_structure.split(domains.SEP_TOKEN)], + dtype=dtype) + + +def deserialize_structures(structures, **kwargs): + """Converts a list of strings to a list of structures. + + Args: + structures: A list of strings produced by `serialize_structures`. + **kwargs: Named arguments passed to `deserialize_structure`. + + Returns: + A list of numpy array. + """ + return [deserialize_structure(structure, **kwargs) + for structure in structures] + + +def serialize_population_frame(frame, inplace=False, domain=None): + """Serializes a population `pd.DataFrame` for representing it as plain text. + + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + inplace: Whether to serialize `frame` inplace instead of creating a copy. + domain: An optional domain for decoding structures. If provided, will + add a column `decoded_structure` with the serialized decoded structures. + + Returns: + A `pd.DataFrame` with serialized structures. + """ + if not inplace: + frame = frame.copy() + if domain: + frame['decoded_structure'] = serialize_structures( + domain.decode(frame['structure'], as_str=False)) + frame['structure'] = serialize_structures(frame['structure']) + return frame + + +def deserialize_population_frame(frame, inplace=False): + """Deserializes a population `pd.DataFrame` from plain text. + + Args: + frame: A `pd.DataFrame` produced by `serialize_population_frame`. + inplace: Whether to deserialize `frame` inplace instead of creating a copy. + + Returns: + A `pd.DataFrame` with deserialized structures. 
+ """ + if not inplace: + frame = frame.copy() + frame['structure'] = deserialize_structures(frame['structure']) + if 'decoded_structure' in frame.columns: + frame['decoded_structure'] = deserialize_structures( + frame['decoded_structure'], dtype=str) + return frame + + +def population_frame_to_csv(frame, path_or_buf=None, domain=None, index=False, + **kwargs): + """Converts a population `pd.DataFrame` to a csv table. + + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + path_or_buf: File path or object. If `None`, the result is returned as a + string. Otherwise write the csv table to that file. + domain: A optional domain for decoding structures. + index: Whether to store the index of `frame`. + **kwargs: Named arguments passed to `frame.to_csv`. + + Returns: + If `path_or_buf` is `None`, returns the resulting csv format as a + string. Otherwise returns `None`. + """ + if frame.empty: + raise ValueError('Cannot write empty population frame to CSV file!') + frame = serialize_population_frame(frame, domain=domain) + return frame.to_csv(path_or_buf, index=index, **kwargs) + + +def population_frame_from_csv(path_or_buf, **kwargs): + """Reads a population `pd.DataFrame` from a file. + + Args: + path_or_buf: A string path of file buffer. + **kwargs: Named arguments passed to `pd.read_csv`. + + Returns: + A `pd.DataFrame`. + """ + frame = pd.read_csv(path_or_buf, dtype={'metadata': object}, **kwargs) + frame = deserialize_population_frame(frame) + return frame + + +def subtract_mean_batch_reward(population): + """Returns new Population where each batch has mean-zero rewards.""" + df = population.to_frame() + mean_dict = df.groupby('batch_index').reward.mean().to_dict() + + def reward_for_sample(sample): + return sample.reward - mean_dict[sample.batch_index] + + shifted_samples = [ + sample.copy(reward=reward_for_sample(sample)) for sample in population + ] + return Population(shifted_samples) + + +def _to_immutable_array(array): + to_return = np.array(array) + to_return.setflags(write=False) + return to_return + + +class _NumericConverter(object): + """Helper class for converting values to a numeric data type.""" + + def __init__(self, dtype, min_value=None, max_value=None): + self._dtype = dtype + self._min_value = min_value + self._max_value = max_value + + def __call__(self, value): + """Validates and converts `value` to `self._dtype`.""" + if value is None: + return value + if not np.isreal(value): + raise TypeError('%s is not numeric!' % value) + value = self._dtype(value) + if self._min_value is not None and value < self._min_value: + raise TypeError('%f < %f' % (value, self._min_value)) + if self._max_value is not None and value > self._max_value: + raise TypeError('%f > %f' % (value, self._max_value)) + return value + + +@attr.s( + frozen=True, # Make it immutable. + slots=True, # Improve memory overhead. + eq=False, # Because we override __eq__. +) +class Sample(object): + """Immutable container for a structure, reward, and additional data. + + Attributes: + key: (str) A unique identifier of the sample. If not provided, will create + a unique identifier using `utils.create_unique_id`. + structure: (np.ndarray) The structure. + reward: (float, optional) The reward. + batch_index: (int, optional) The batch index within the population. + infeasible: Whether the sample was marked as infeasible by the evaluator. + metadata: (any type, optional) Additional meta-data. 
+ """ + structure: np.ndarray = attr.ib(factory=np.array, + converter=_to_immutable_array) # pytype: disable=wrong-arg-types # attr-stubs + reward: float = attr.ib( + converter=_NumericConverter(float), + default=None) + batch_index: int = attr.ib( + converter=_NumericConverter(int, min_value=0), + default=None) + infeasible: bool = attr.ib(default=False, + validator=attr.validators.instance_of(bool)) + key: str = attr.ib( + factory=utils.create_unique_id, + validator=attr.validators.optional(attr.validators.instance_of(str))) + metadata: typing.Dict[str, typing.Any] = attr.ib(default=None) + + def __eq__(self, other): + """Compares samples irrespective of their key.""" + return (self.reward == other.reward and + self.batch_index == other.batch_index and + self.infeasible == other.infeasible and + self.metadata == other.metadata and + np.array_equal(self.structure, other.structure)) + + def equal(self, other): + """Compares samples including their key.""" + return self == other and self.key == other.key + + def to_tfexample(self): + """Converts a Sample to a tf.Example.""" + features = dict( + structure=tf.train.Feature( + int64_list=tf.train.Int64List(value=self.structure)), + reward=tf.train.Feature( + float_list=tf.train.FloatList(value=[self.reward])), + batch_index=tf.train.Feature( + int64_list=tf.train.Int64List(value=[self.batch_index]))) + return tf.train.Example(features=tf.train.Features(feature=features)) + + def to_dict(self, dict_factory=collections.OrderedDict): + """Returns a dictionary from field names to values. + + Args: + dict_factory: A class that implements a dict factory method. + + Returns: + A dict of type `dict_factory` + """ + return attr.asdict(self, dict_factory=dict_factory) + + def copy(self, new_key=True, **kwargs): + """Returns a copy of this Sample with values overridden by **kwargs.""" + if new_key: + kwargs['key'] = utils.create_unique_id() + return attr.evolve(self, **kwargs) + + +def samples_from_arrays(structures, rewards=None, batch_index=None, + metadata=None): + """Makes a generator of Samples from fields. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of float rewards. If None, the corresponding Samples are + given each given a reward of None. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. + metadata: Metadata to store in the Sample. + + Yields: + A generator of Samples + """ + structures = utils.to_array(structures) + + if metadata is None: + metadata = [None] * len(structures) + + if rewards is None: + rewards = [None] * len(structures) + else: + rewards = utils.to_array(rewards) + + if len(structures) != len(rewards): + raise ValueError( + 'Structures and rewards must be same length. Are %s and %s' % + (len(structures), len(rewards))) + if len(metadata) != len(rewards): + raise ValueError('Metadata and rewards must be same length. Are %s and %s' % + (len(metadata), len(rewards))) + + if batch_index is None: + batch_index = 0 + if isinstance(batch_index, int): + batch_index = [batch_index] * len(structures) + + for structure, reward, batch_index, meta in zip(structures, rewards, + batch_index, metadata): + yield Sample( + structure=structure, + reward=reward, + batch_index=batch_index, + metadata=meta) + + +def parse_tf_example(example_proto): + """Converts tf.Example proto to dict of Tensors. + + Args: + example_proto: A raw tf.Example proto. 
+ Returns: + A dict of Tensors with fields structure, reward, and batch_index. + """ + + feature_description = dict( + structure=tf.FixedLenSequenceFeature((), tf.int64, allow_missing=True), + reward=tf.FixedLenFeature([1], tf.float32), + batch_index=tf.FixedLenFeature([1], tf.int64)) + + return tf.io.parse_single_example( + serialized=example_proto, features=feature_description) + + +class Population(object): + """Data structure for storing Samples.""" + + def __init__(self, samples=None): + """Construct a Population. + + Args: + samples: An iterable of Samples + """ + self._samples = collections.OrderedDict() + self._batch_to_sample_keys = collections.defaultdict(list) + + if samples is not None: + self.add_samples(samples) + + def __str__(self): + if self.empty: + return '' + return ('' % + (len(self), dict(self.best().to_dict()))) + + def __len__(self): + return len(self._samples) + + def __iter__(self): + return self._samples.values().__iter__() + + def __eq__(self, other): + if not isinstance(other, Population): + raise ValueError('Cannot compare equality with an object of ' + 'type %s' % (str(type(other)))) + + return len(self) == len(other) and all( + s1 == s2 for s1, s2 in zip(self, other)) + + def __add__(self, other): + """Adds samples to this population and returns a new Population.""" + return Population(itertools.chain(self, other)) + + def __getitem__(self, key_or_index): + if isinstance(key_or_index, str): + return self._samples[key_or_index] + else: + return list(self._samples.values())[key_or_index] + + def __contains__(self, key): + if isinstance(key, Sample): + key = key.key + return key in self._samples + + def copy(self): + """Copies the population.""" + return Population(self.samples) + + @property + def samples(self): + """Returns the population Samples as a list.""" + return list(self._samples.values()) + + def add_sample(self, sample): + """Add copy of sample to population.""" + if sample.key in self._samples: + raise ValueError('Sample with key %s already exists in the population!') + self._samples[sample.key] = sample + batch_idx = sample.batch_index + self._batch_to_sample_keys[batch_idx].append(sample.key) + + def add_samples(self, samples): + """Convenience method for adding multiple samples.""" + for sample in samples: + self.add_sample(sample) + + @property + def empty(self): + return not self + + @property + def batch_indices(self): + """Returns a sorted list of unique batch indices.""" + return sorted(self._batch_to_sample_keys) + + @property + def max_batch_index(self): + """Returns the maximum batch index.""" + if self.empty: + raise ValueError('Population empty!') + return self.batch_indices[-1] + + @property + def current_batch_index(self): + """Return the maximum batch index or -1 if the population is empty.""" + return -1 if self.empty else self.max_batch_index + + def get_batches(self, batch_indices, exclude=False, validate=True): + """"Extracts certain batches from the population. + + Ignores batches that do not exist in the population. To validate if a + certain batch exists use `batch_index in population.batch_indices`. + + Args: + batch_indices: An integer, iterable of integers, or `slice` object + for selecting batches. + exclude: If true, will return all batches but the ones that are selected. + validate: Whether to raise an exception if a batch index is invalid + instead of ignoring. + + Returns: + A `Population` with the selected batches. 
+ """ + batch_indices = utils.get_indices( + self.batch_indices, batch_indices, exclude=exclude, validate=validate) + sample_idxs = [] + for batch_index in batch_indices: + sample_idxs.extend(self._batch_to_sample_keys.get(batch_index, [])) + samples = [self[idx] for idx in sample_idxs] + return Population(samples) + + def get_batch(self, *args, **kwargs): + return self.get_batches(*args, **kwargs) + + def get_last_batch(self, **kwargs): + """Returns the last batch from the population.""" + return self.get_batch(-1, **kwargs) + + def get_last_batches(self, n=1, **kwargs): + """Selects the last n batches.""" + return self.get_batches(self.batch_indices[-n:], **kwargs) + + def to_structures_and_rewards(self): + """Return (list of structures, list of rewards) in the Population.""" + structures = [] + rewards = [] + for sample in self: + structures.append(sample.structure) + rewards.append(sample.reward) + return structures, rewards + + @property + def structures(self): + """Returns the structure of all samples in the population.""" + return [sample.structure for sample in self] + + @property + def rewards(self): + """Returns the reward of all samples in the population.""" + return [sample.reward for sample in self] + + @staticmethod + def from_arrays(structures, rewards=None, batch_index=0, metadata=None): + """Creates Population from the specified fields. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of float rewards. If None, the corresponding Samples are + given each given a reward of None. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. + metadata: Metadata to store in the Samples. + + Returns: + A Population. + """ + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + return Population(samples) + + def add_batch(self, structures, rewards=None, batch_index=None, + metadata=None): + """Adds a batch of samples to the Population. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of rewards. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. If `None`, uses `self.current_batch + 1`. + metadata: Metadata to store in the Samples. 
+ """ + if batch_index is None: + batch_index = self.current_batch_index + 1 + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + self.add_samples(samples) + + def head(self, n): + """Returns new Population containing first n samples.""" + + return Population(self.samples[:n]) + + @staticmethod + def from_dataset(dataset): + """Converts dataset of DatasetSample to Population.""" + samples = utils.arrays_from_dataset(dataset) + return Population.from_arrays(samples.structure, samples.reward) + + def to_dataset(self): + """Converts the population to a `tf.data.Dataset` with `DatasetSample`s.""" + structures, rewards = self.to_structures_and_rewards() + return utils.dataset_from_tensors( + DatasetSample(structure=structures, reward=rewards)) + + @staticmethod + def from_tfrecord(filename): + """Reads Population from tfrecord file.""" + raw_dataset = tf.data.TFRecordDataset([filename]) + parsed_dataset = raw_dataset.map(parse_tf_example) + + def _record_to_dict(record): + mapping = {key: utils.to_array(value) for key, value in record.items()} + if 'batch_index' in mapping: + mapping['batch_index'] = int(mapping['batch_index']) + return mapping + + return Population(Sample(**_record_to_dict(r)) for r in parsed_dataset) + + def to_tfrecord(self, filename): + """Writes Population to tfrecord file.""" + + with tf.python_io.TFRecordWriter(filename) as writer: + for sample in self.samples: + writer.write(sample.to_tfexample().SerializeToString()) + + def to_frame(self): + """Converts a `Population` to a `pd.DataFrame`.""" + records = [sample.to_dict() for sample in self.samples] + return pd.DataFrame.from_records(records) + + @classmethod + def from_frame(cls, frame): + """Converts a `pd.DataFrame` to a `Population`.""" + if frame.empty: + return cls() + population = cls() + for _, row in frame.iterrows(): + sample = Sample(**row.to_dict()) + population.add_sample(sample) + return population + + def to_csv(self, path, domain=None): + """Stores a population to a CSV file. + + Args: + path: The output file path. + domain: An optional `domains.Domain`. If provided, will also store + decoded structures in the CSV file. + """ + population_frame_to_csv(self.to_frame(), path, domain=domain) + + @classmethod + def from_csv(cls, path): + """Restores a population from a CSV file. + + Args: + path: The CSV file path. + + Returns: + An instance of a `Population`. + """ + return cls.from_frame(population_frame_from_csv(path).drop( + columns=['decoded_structure'], errors='ignore')) + + def best_n(self, n=1, q=None, discard_duplicates=False, blacklist=None): + """Returns the best n samples. + + Note that ties are broken deterministically. + + Args: + n: Max number to return + q: A float in (0, 1) corresponding to the minimum quantile for selecting + samples. If provided, `n` is ignored and samples with a reward >= + this quantile are selected. + discard_duplicates: If True, when several samples have the same structure, + return only one of them (the selected one is unspecified). + blacklist: Iterable of structures that should be excluded. + + Returns: + Population containing the best n Samples, sorted in decreasing order of + reward (output[0] is the best). Returns less than n if there are fewer + than n Samples in the population. + """ + if self.empty: + raise ValueError('Population empty.') + + samples = self.samples + if blacklist: + samples = self._filter(samples, blacklist) + + # are unique. 
+ if discard_duplicates and len(samples) > 1: + samples = utils.deduplicate_samples(samples) + + samples = sorted(samples, key=lambda sample: sample.reward, reverse=True) + if q is not None: + q_value = np.quantile([sample.reward for sample in samples], q) + return Population( + [sample for sample in samples if sample.reward >= q_value]) + else: + return Population(samples[:n]) + + def best(self, blacklist=None): + return self.best_n(1, blacklist=blacklist).samples[0] + + def _filter(self, samples, blacklist): + blacklist = set(utils.hash_structure(structure) for structure in blacklist) + return Population( + sample for sample in samples + if utils.hash_structure(sample.structure) not in blacklist) + + def contains_structure(self, structure): + return self.contains_structures([structure])[0] + + # TODO(dbelanger): this has computational overhead because it hashes the + # entire population. If this function is being called often, the set of hashed + # structures should be built up incrementally as elements are added to the + # population. Right now, the function is rarely called. + # TODO(ddohan): Build this just in time: track what hasn't been hashed + # since last call, and add any new structures before checking contain. + def contains_structures(self, structures): + structures_in_population = set( + utils.hash_structure(sample.structure) for sample in self.samples) + return [ + utils.hash_structure(structure) in structures_in_population + for structure in structures + ] + + def deduplicate(self, select_best=False): + """De-duplicates Samples with identical structures. + + Args: + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. + + Returns: + A Population with de-duplicated samples. + """ + return Population(utils.deduplicate_samples(self, select_best=select_best)) + + def discard_infeasible(self): + """Returns a new `Population` with all infeasible samples removed.""" + return Population(sample for sample in self if not sample.infeasible) \ No newline at end of file diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py new file mode 100644 index 0000000..cdf13de --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py @@ -0,0 +1,978 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Deep Evolution solver. + +Uses a neural net to predict the position and +mutation function to apply to a string. Neural net takes the string and predicts +1) [Batch x length] logits over positions in the string +2) [Batch x length x n_mutations] logits over mutation function for every +position in the string. 
+First, we sample the position from the position logits, take the logits +corresponding to the chosen position and sample the index of the mutation +function to apply to this position in the string. Currently, we apply +one mutation at a time. Finally, update the network parameters using REINFORCE +gradient estimator, where the advantage is the difference between parent and +child rewards. The log-likelihood is the sum of position and mutation +log-likelihoods. +By default, no selection is performed (we continue mutating the same batch, +use_selection_of_best = False). If use_selection_of_best=True, we choose best +samples from the previous batch and sample them with replacement to create +a new batch. +""" +import functools + +#from absl import logging +#import gin +import jax +from jax.example_libraries import stax +from jax.example_libraries.optimizers import adam +import jax.numpy as jnp +import jax.random as jrand +from jax.scipy.special import logsumexp +import numpy as np + +from poli_baselines.solvers.bayesian_optimization.amortized import base_solver +from poli_baselines.solvers.bayesian_optimization.amortized import data +from poli_baselines.solvers.bayesian_optimization.amortized import utils + + +def logsoftmax(x, axis=-1): + """Apply log softmax to an array of logits, log-normalizing along an axis.""" + return x - logsumexp(x, axis, keepdims=True) + + +def softmax(x, axis=-1): + return jnp.exp(logsoftmax(x, axis)) + + +def one_hot(x, k): + """Create a one-hot encoding of x of size k.""" + return jnp.eye(k)[x] + + +def gumbel_max_sampler(logits, temperature, rng): + """Sample fom categorical distribution using Gumbel-Max trick. + + Gumbel-Max trick: + https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + https://arxiv.org/abs/1411.0030 + + Args: + logits: Unnormalized logits for categorical distribution. + [batch x n_mutations_to_sample x n_mutation_types] + temperature: temperature parameter for Gumbel-Max. The lower the + temperature, the closer the sample is to one-hot-encoding. + rng: Jax random number generator + + Returns: + class_assignments: Sampled class assignments [batch] + log_likelihoods: Log-likelihoods of the sampled mutations [batch] + """ + + # Normalize the logits + logits = logsoftmax(logits) + + gumbel_noise = jrand.gumbel(rng, logits.shape) + softmax_logits = (logits + gumbel_noise) / temperature + soft_assignments = softmax(softmax_logits, -1) + class_assignments = jnp.argmax(soft_assignments, -1) + assert len(class_assignments.shape) == 2 + # Output shape: [batch x num_mutations] + + return class_assignments + + +########################################## +# Mutation-related helper functions +def _mutate_position(structure, pos_mask, fn): + """Apply mutation fn to position specified by pos_mask.""" + structure = np.array(structure).copy() + pos_mask = np.array(pos_mask).astype(int) + structure[pos_mask == 1] = fn(structure[pos_mask == 1]) + return structure + + +def set_pos(x, pos_mask, val): + return _mutate_position(x, pos_mask, fn=lambda x: val) + + +def apply_mutations(samples, mutation_types, pos_masks, mutations, + use_assignment_mutations=False): + """Apply the mutations specified by mutation types to the batch of strings. + + Args: + samples: Batch of strings [batch x str_length] + mutation_types: IDs of mutation types to be applied to each string + [Batch x num_mutations] + pos_masks: One-hot encoding [Batch x num_mutations x str_length] + of the positions to be mutate in each string. 
+ "num_mutations" positions will be mutated per string. + mutations: A list of possible mutation functions. + Functions should follow the format: fn(x, domain, pos_mask), + use_assignment_mutations: bool. Whether mutations are defined as + "Set position X to character C". If use_assignment_mutations=True, + then vectorize procedure of applying mutations to the string. + The index of mutation type should be equal to the index of the character. + Gives considerable speed-up to this function. + + Returns: + perturbed_samples: Strings perturbed according to the mutation list. + """ + batch_size = samples.shape[0] + assert len(mutation_types) == batch_size + assert len(pos_masks) == batch_size + + str_length = samples.shape[1] + assert pos_masks.shape[-1] == str_length + + # Check that number of mutations is consistent in mutation_types and positions + assert mutation_types.shape[1] == pos_masks.shape[1] + + num_mutations = mutation_types.shape[1] + + # List of batched samples with 0,1,2,... mutations + # First element of the list contains original samples + # Last element has samples with all mutations applied to the string + perturbed_samples_with_i_mutations = [samples] + for i in range(num_mutations): + + perturbed_samples = [] + samples_to_perturb = perturbed_samples_with_i_mutations[-1] + + if use_assignment_mutations: + perturbed_samples = samples_to_perturb.copy() + mask = pos_masks[:, i].astype(int) + # Assumes mutations are defined as "Set position to the character C" + perturbed_samples[np.array(mask) == 1] = mutation_types[:, i] + else: + for j in range(batch_size): + sample = samples_to_perturb[j].copy() + + pos = pos_masks[j, i] + mut_id = mutation_types[j, i] + + mutation = mutations[int(mut_id)] + perturbed_samples.append(mutation(sample, pos)) + perturbed_samples = np.stack(perturbed_samples) + + assert perturbed_samples.shape == samples.shape + perturbed_samples_with_i_mutations.append(perturbed_samples) + + states = jnp.stack(perturbed_samples_with_i_mutations, 0) + assert states.shape == (num_mutations + 1,) + samples.shape + return states + + +########################################## +# pylint: disable=invalid-name +def OneHot(depth): + """Layer for transforming inputs to one-hot encoding.""" + + def init_fun(rng, input_shape): + del rng + return input_shape + (depth,), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + # Perform one-hot encoding + return jnp.eye(depth)[inputs.astype(int)] + + return init_fun, apply_fun + + +def ExpandDims(axis=1): + """Layer for expanding dimensions.""" + def init_fun(rng, input_shape): + del rng + input_shape = tuple(input_shape) + if axis < 0: + dims = len(input_shape) + new_axis = dims + 1 - axis + else: + new_axis = axis + return (input_shape[:new_axis] + (1,) + input_shape[new_axis:]), () + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return jnp.expand_dims(inputs, axis) + return init_fun, apply_fun + + +def AssertNonZeroShape(): + """Layer for checking that no dimension has zero length.""" + + def init_fun(rng, input_shape): + del rng + return input_shape, () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + assert 0 not in inputs.shape + return inputs + + return init_fun, apply_fun + +# pylint: enable=invalid-name + + +def squeeze_layer(axis=1): + """Layer for squeezing dimension along the axis.""" + def init_fun(rng, input_shape): + del rng + if axis < 0: + raise ValueError("squeeze_layer: negative axis is not supported") + return (input_shape[:axis] + input_shape[(axis + 
1):]), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return inputs.squeeze(axis) + + return init_fun, apply_fun + + +def reduce_layer(reduce_fn=jnp.mean, axis=1): + """Apply reduction function to the array along axis.""" + def init_fun(rng, input_shape): + del rng + assert axis >= 0 + assert len(input_shape) == 3 + return input_shape[:axis - 1] + input_shape[axis + 1:], () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return reduce_fn(inputs, axis=axis) + + return init_fun, apply_fun + + +def _create_positional_encoding( # pylint: disable=invalid-name + input_shape, max_len=10000): + """Helper: create positional encoding parameters.""" + d_feature = input_shape[-1] + pe = np.zeros((max_len, d_feature), dtype=np.float32) + position = np.arange(0, max_len)[:, np.newaxis] + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + pe[:, 0::2] = np.sin(position * div_term) + pe[:, 1::2] = np.cos(position * div_term) + pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] + return jnp.array(pe) # These are trainable parameters, initialized as above. + + +def positional_encoding(): + """Concatenate positional encoding to the last dimension.""" + def init_fun(rng, input_shape): + del rng + input_shape_for_enc = input_shape + params = _create_positional_encoding(input_shape_for_enc) + last_dim = input_shape[-1] + params.shape[-1] + return input_shape[:-1] + (last_dim,), (params,) + + def apply_fun(params, inputs, **kwargs): + del kwargs + assert inputs.ndim == 4 + params = params[0] + symbol_size = inputs.shape[-2] + enc = params[None, :, :symbol_size, :] + enc = jnp.repeat(enc, inputs.shape[0], 0) + return jnp.concatenate((inputs, enc), -1) + + return init_fun, apply_fun + + +def cnn(conv_depth=300, + kernel_size=5, + n_conv_layers=2, + across_batch=False, + add_pos_encoding=False): + """Build convolutional neural net.""" + # Input shape: [batch x length x depth] + if across_batch: + extra_dim = 0 + else: + extra_dim = 1 + layers = [ExpandDims(axis=extra_dim)] + if add_pos_encoding: + layers.append(positional_encoding()) + + for _ in range(n_conv_layers): + layers.append( + stax.Conv(conv_depth, (1, kernel_size), padding="same", strides=(1, 1))) + layers.append(stax.Relu) + layers.append(AssertNonZeroShape()) + layers.append(squeeze_layer(axis=extra_dim)) + return stax.serial(*layers) + + +def build_model_stax(output_size, + n_dense_units=300, + conv_depth=300, + n_conv_layers=2, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + mode="train"): + """Build a model with convolutional layers followed by dense layers.""" + del mode + layers = [ + cnn(conv_depth=conv_depth, + n_conv_layers=n_conv_layers, + kernel_size=kernel_size, + across_batch=across_batch, + add_pos_encoding=add_pos_encoding) + ] + for _ in range(n_dense_layers): + layers.append(stax.Dense(n_dense_units)) + layers.append(stax.Relu) + + layers.append(stax.Dense(output_size)) + + if mean_over_pos: + layers.append(reduce_layer(jnp.mean, axis=1)) + init_random_params, predict = stax.serial(*layers) + return init_random_params, predict + + +def sample_log_probs_top_k(log_probs, rng, temperature=1., k=1): + """Sample categorical distribution of log probs using gumbel max trick.""" + noise = jax.random.gumbel(rng, shape=log_probs.shape) + perturbed = (log_probs + noise) / temperature + samples = jnp.argsort(perturbed)[Ellipsis, -k:] + return samples + + +@jax.jit +def gather_positions(idx_to_gather, logits): + 
"""Collect logits corresponding to the positions in the string. + + Used for collecting logits for: + 1) positions in the string (depth = 1) + 2) mutation types (depth = n_mut_types) + + Args: + idx_to_gather: [batch_size x num_mutations] Indices of the positions + in the string to gather logits for. + logits: [batch_size x str_length x depth] Logits to index. + + Returns: + Logits corresponding to the specified positions in the string: + [batch_size, num_mutations, depth] + """ + assert idx_to_gather.shape[0] == logits.shape[0] + assert idx_to_gather.ndim == 2 + assert logits.ndim == 3 + + batch_size, num_mutations = idx_to_gather.shape + batch_size, str_length, depth = logits.shape + + oh = one_hot(idx_to_gather, str_length) + assert oh.shape == (batch_size, num_mutations, str_length) + + oh = oh[Ellipsis, None] + logits = logits[:, None, :, :] + assert oh.shape == (batch_size, num_mutations, str_length, 1) + assert logits.shape == (batch_size, 1, str_length, depth) + + # Perform element-wise multiplication (with broadcasting), + # then sum over str_length dimension + result = jnp.sum(oh * logits, axis=-2) + assert result.shape == (batch_size, num_mutations, depth) + return result + + +class JaxMutationPredictor(object): + """Implements training and predicting from a Jax model. + + Attributes: + output_size: Tuple containing the sizes of components to predict + loss_fn: Loss function. + Format of the loss fn: fn(params, batch, mutations, problem, predictor) + loss_grad_fn: Gradient of the loss function + temperature: temperature parameter for Gumbel-Max sampler. + learning_rate: Learning rate for optimizer. + batch_size: Batch size of input + model_fn: Function which builds the model forward pass. Must have arguments + `vocab_size`, `max_len`, and `mode` and return Jax float arrays. + params: weights of the neural net + make_state: function to make optimizer state given the network parameters + rng: Jax random number generator + """ + + def __init__(self, + vocab_size, + output_size, + loss_fn, + rng, + temperature=1, + learning_rate=0.001, + conv_depth=300, + n_conv_layers=2, + n_dense_units=300, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + model_fn=build_model_stax): + self.output_size = output_size + self.temperature = temperature + + # Setup randomness. + self.rng = rng + + model_settings = { + "output_size": output_size, + "n_dense_units": n_dense_units, + "n_dense_layers": n_dense_layers, + "conv_depth": conv_depth, + "n_conv_layers": n_conv_layers, + "across_batch": across_batch, + "kernel_size": kernel_size, + "add_pos_encoding": add_pos_encoding, + "mean_over_pos": mean_over_pos, + "mode": "train" + } + + self._model_init, model_train = model_fn(**model_settings) + self._model_train = jax.jit(model_train) + + model_settings["mode"] = "eval" + _, model_predict = model_fn(**model_settings) + self._model_predict = jax.jit(model_predict) + + self.rng, subrng = jrand.split(self.rng) + _, init_params = self._model_init(subrng, (-1, -1, vocab_size)) + self.params = init_params + + # Setup parameters for model and optimizer + self.make_state, self._opt_update_state, self._get_params = adam( + learning_rate) + + self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model) + self.loss_grad_fn = jax.grad(self.loss_fn) + + # Track steps of optimization so far. + self._step_idx = 0 + + def update_step(self, rewards, inputs, actions): + """Performs a single update step on a batch of samples. 
+ + Args: + rewards: Batch [batch] of rewards for perturbed samples. + inputs: Batch [batch x length] of original samples + actions: actions applied on the samples + + Raises: + ValueError: if any inputs are the wrong shape. + """ + + grad_update = self.loss_grad_fn( + self.params, + rewards=rewards, + inputs=inputs, + actions=actions, + ) + + old_params = self.params + state = self.make_state(old_params) + state = self._opt_update_state(self._step_idx, grad_update, state) + + self.params = self._get_params(state) + del old_params, state + + self._step_idx += 1 + + def __call__(self, x, mode="eval"): + """Calls predict function of model. + + Args: + x: Batch of input samples. + mode: Mode for running the network: "train" or "eval" + + Returns: + A list of tuples (class weights, log likelihood) for each of + output components predicted by the model. + """ + return self.run_model(x, self.params, mode="eval") + + def run_model(self, x, params, mode="eval"): + """Run the Jax model. + + This function is used in __call__ to run the model in "eval" mode + and in the loss function to run the model in "train" mode. + + Args: + x: Batch of input samples. + params: Network parameters + mode: Mode for running the network: "train" or "eval" + + Returns: + Jax neural network output. + """ + if mode == "train": + model_fn = self._model_train + else: + model_fn = self._model_predict + self.rng, subrng = jax.random.split(self.rng) + return model_fn(params, inputs=x, rng=subrng) + + +######################################### +# Loss function +def reinforce_loss(rewards, log_likelihood): + """Loss function for Jax model. + + Args: + rewards: List of rewards [batch] for the perturbed samples. + log_likelihood: Log-likelihood of perturbations + + Returns: + Scalar loss. + """ + rewards = jax.lax.stop_gradient(rewards) + + # In general case, we assume that the loss is not differentiable + # Use REINFORCE + reinforce_estim = rewards * log_likelihood + # Take mean over the number of applied mutations, then across the batch + return -jnp.mean(jnp.mean(reinforce_estim, 1), 0) + + +def compute_entropy(log_probs): + """Compute entropy of a set of log_probs.""" + return -jnp.mean(jnp.mean(stax.softmax(log_probs) * log_probs, axis=-1)) + + +def compute_advantage(params, critic_fn, rewards, inputs): + """Compute the advantage: difference between rewards and predicted value. + + Args: + params: parameters for the critic neural net + critic_fn: function to run critic neural net + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + + Returns: + advantage: [batch_size x num_mutations] + """ + assert inputs.ndim == 4 + + num_mutations, batch_size, str_length, vocab_size = inputs.shape + + inputs_reshaped = inputs.reshape( + (num_mutations * batch_size, str_length, vocab_size)) + + predicted_value = critic_fn(inputs_reshaped, params, mode="train") + assert predicted_value.shape == (num_mutations * batch_size, 1) + predicted_value = predicted_value.reshape((num_mutations, batch_size)) + + assert rewards.shape == (batch_size,) + rewards = jnp.repeat(rewards[None, :], num_mutations, 0) + assert rewards.shape == (num_mutations, batch_size) + + advantage = rewards - predicted_value + advantage = jnp.transpose(advantage) + assert advantage.shape == (batch_size, num_mutations) + return advantage + + +def value_loss_fn(params, run_model_fn, rewards, inputs, actions=None): + """Compute the loss for the value function. 
+ + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: not used + + Returns: + A scalar loss. + """ + del actions + advantage = compute_advantage(params, run_model_fn, rewards, inputs) + advantage = advantage**2 + + return jnp.sqrt(jnp.mean(advantage)) + + +def split_mutation_predictor_output(output): + return stax.logsoftmax(output[:, :, -1]), stax.logsoftmax(output[:, :, :-1]) + + +def run_model_and_compute_reinforce_loss(params, + run_model_fn, + rewards, + inputs, + actions, + n_mutations, + entropy_weight=0.1): + """Run Jax model and compute REINFORCE loss. + + Jax can compute the gradients of the model only if the model is called inside + the loss function. Here we call the Jax model, re-compute the log-likelihoods, + take log-likelihoods of the mutations and positions sampled before in + _propose function of the solver, and compute the loss. + + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: Tuple (mut_types [Batch], positions [Batch]) of mutation types + and positions sampled during the _propose() step of evolution solver. + n_mutations: Number of mutations. Used for one-hot encoding of mutations + entropy_weight: Weight on the entropy term added to the loss. + + Returns: + A scalar loss. + + """ + mut_types, positions = actions + mut_types_one_hot = one_hot(mut_types, n_mutations) + + batch_size, str_length, _ = inputs.shape + assert mut_types.shape[0] == inputs.shape[0] + batch_size, num_mutations = mut_types.shape + assert mut_types.shape == positions.shape + assert mut_types.shape == rewards.shape + + output = run_model_fn(inputs, params, mode="train") + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + assert pos_log_probs.shape == (batch_size, str_length) + pos_log_probs = jnp.expand_dims(pos_log_probs, -1) + + pos_log_likelihoods = gather_positions(positions, pos_log_probs) + assert pos_log_likelihoods.shape == (batch_size, num_mutations, 1) + + # Sum over number of positions + pos_log_likelihoods = jnp.sum(pos_log_likelihoods, -1) + + # all_mut_log_probs shape: [batch_size, str_length, n_mut_types] + assert all_mut_log_probs.shape[:2] == (batch_size, str_length) + + # Get mutation logits corresponding to the chosen positions + mutation_logprobs = gather_positions(positions, all_mut_log_probs) + + # Get log probs corresponding to the selected mutations + mut_log_likelihoods_oh = mutation_logprobs * mut_types_one_hot + + # Sum over mutation types + mut_log_likelihoods = jnp.sum(mut_log_likelihoods_oh, -1) + assert mut_log_likelihoods.shape == (batch_size, num_mutations) + + joint_log_likelihood = mut_log_likelihoods + pos_log_likelihoods + assert joint_log_likelihood.shape == (batch_size, num_mutations) + + loss = reinforce_loss(rewards, joint_log_likelihood) + loss -= entropy_weight * compute_entropy(mutation_logprobs) + return loss + + +############################################ +# MutationPredictorSolver +def initialize_uniformly(domain, batch_size, random_state): + return domain.sample_uniformly(batch_size, seed=random_state) + + +#@gin.configurable +class MutationPredictorSolver(base_solver.BaseSolver): + """Choose the mutation operator conditioned on the sample. 
+ + Sample from categorical distribution over available mutation operators + using Gumbel-Max trick + """ + + def __init__(self, + domain, + model_fn=build_model_stax, + random_state=0, + **kwargs): + """Constructs solver. + + Args: + domain: discrete domain + model_fn: Function which builds the forward pass of predictor model. + random_state: Random state to initialize jax & np RNGs. + **kwargs: kwargs passed to config. + """ + super(MutationPredictorSolver, self).__init__( + domain=domain, random_state=random_state, **kwargs) + self.rng = jrand.PRNGKey(random_state) + self.rng, rng = jax.random.split(self.rng) + + if self.domain.length < self.cfg.num_mutations: + # logging.warning("Number of mutations to perform per string exceeds string" + # " length. The number of mutation is set to be equal to " + # "the string length.") + self.cfg.num_mutations = self.domain.length + + # Right now the mutations are defined as "Set position X to character C". + # It allows to vectorize applying mutations to the string and speeds up + # the solver. + # If using other types of mutations, set self.use_assignment_mut=False. + self.mutations = [] + for val in range(self.domain.vocab_size): + self.mutations.append(functools.partial(set_pos, val=val)) + self.use_assignment_mut = True + + mut_loss_fn = functools.partial(run_model_and_compute_reinforce_loss, + n_mutations=len(self.mutations)) + + # Predictor that takes the input string + # Outputs the weights over the 1) mutations types 2) position in string + if self.cfg.pretrained_model is None: + self._mut_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=len(self.mutations) + 1, + loss_fn=mut_loss_fn, + rng=rng, + model_fn=build_model_stax, + conv_depth=self.cfg.actor_conv_depth, + n_conv_layers=self.cfg.actor_n_conv_layers, + n_dense_units=self.cfg.actor_n_dense_units, + n_dense_layers=self.cfg.actor_n_dense_layers, + across_batch=self.cfg.actor_across_batch, + add_pos_encoding=self.cfg.actor_add_pos_encoding, + kernel_size=self.cfg.actor_kernel_size, + learning_rate=self.cfg.actor_learning_rate, + ) + + if self.cfg.use_actor_critic: + self._value_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=1, + rng=rng, + loss_fn=value_loss_fn, + model_fn=build_model_stax, + mean_over_pos=True, + conv_depth=self.cfg.critic_conv_depth, + n_conv_layers=self.cfg.critic_n_conv_layers, + n_dense_units=self.cfg.critic_n_dense_units, + n_dense_layers=self.cfg.critic_n_dense_layers, + across_batch=self.cfg.critic_across_batch, + add_pos_encoding=self.cfg.critic_add_pos_encoding, + kernel_size=self.cfg.critic_kernel_size, + learning_rate=self.cfg.critic_learning_rate, + ) + + else: + self._value_predictor = None + else: + self._mut_predictor, self._value_predictor = self.cfg.pretrained_model + + self._data_for_grad_update = [] + self._initialized = False + + def _config(self): + cfg = super(MutationPredictorSolver, self)._config() + + cfg.update( + dict( + predictor=JaxMutationPredictor, + temperature=1., + initialize_dataset_fn=initialize_uniformly, + elite_set_size=10, + use_random_network=False, + exploit_with_best=True, + use_selection_of_best=False, + pretrained_model=None, + + # Indicator to BO to pass in previous weights. + # As implemented in cl/318101597. 
+ warmstart=True, + + use_actor_critic=False, + num_mutations=5, + + # Hyperparameters for actor + actor_learning_rate=0.001, + actor_conv_depth=300, + actor_n_conv_layers=1, + actor_n_dense_units=100, + actor_n_dense_layers=0, + actor_kernel_size=5, + actor_across_batch=False, + actor_add_pos_encoding=True, + + # Hyperparameters for critic + critic_learning_rate=0.001, + critic_conv_depth=300, + critic_n_conv_layers=1, + critic_n_dense_units=300, + critic_n_dense_layers=0, + critic_kernel_size=5, + critic_across_batch=False, + critic_add_pos_encoding=True, + )) + return cfg + + def _get_unique(self, samples): + unique_population = data.Population() + unique_structures = set() + for sample in samples: + hashed_structure = utils.hash_structure(sample.structure) + if hashed_structure in unique_structures: + continue + unique_structures.add(hashed_structure) + unique_population.add_samples([sample]) + return unique_population + + def _get_best_samples_from_last_batch(self, + population, + n=1, + discard_duplicates=True): + best_samples = population.get_last_batch().best_n( + n, discard_duplicates=discard_duplicates) + return best_samples.structures, best_samples.rewards + + def _select(self, population): + if self.cfg.use_selection_of_best: + # Choose best samples from the previous batch + structures, rewards = self._get_best_samples_from_last_batch( + population, self.cfg.elite_set_size) + + # Choose the samples to perturb with replacement + idx = np.random.choice(len(structures), self.batch_size, replace=True) # pytype: disable=attribute-error # trace-all-classes + selected_structures = np.stack([structures[i] for i in idx]) + selected_rewards = np.stack([rewards[i] for i in idx]) + return selected_structures, selected_rewards + else: + # Just return the samples from the previous batch -- no selection + last_batch = population.get_last_batch() + structures = np.array([x.structure for x in last_batch]) + rewards = np.array([x.reward for x in last_batch]) + + if len(last_batch) > self.batch_size: # pytype: disable=attribute-error # trace-all-classes + # Subsample the data + idx = np.random.choice(len(last_batch), self.batch_size, replace=False) # pytype: disable=attribute-error # trace-all-classes + structures = np.stack([structures[i] for i in idx]) + rewards = np.stack([rewards[i] for i in idx]) + return structures, rewards + + def propose(self, num_samples, population=None, pending_samples=None): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + # Initialize population randomly. 
+ if self._initialized and population: + if num_samples != self.batch_size: + raise ValueError("Must maintain constant batch size between runs.") + counter = population.max_batch_index + if counter > 0: + if not self.cfg.use_random_network: + self._update_params(population) + else: + self.batch_size = num_samples + self._initialized = True + return self.cfg.initialize_dataset_fn( + self.domain, num_samples, random_state=self._random_state) + + # Choose best samples so far -- [elite_set_size] + samples_to_perturb, parent_rewards = self._select(population) + + perturbed, actions, mut_predictor_input = self._perturb(samples_to_perturb) + + if not self.cfg.use_random_network: + self._data_for_grad_update.append({ + "batch_index": population.current_batch_index + 1, + "mut_predictor_input": mut_predictor_input, + "actions": actions, + "parent_rewards": parent_rewards, + }) + + return np.asarray(perturbed) + + def _perturb(self, parents, mode="train"): + length = parents.shape[1] + assert length == self.domain.length + + parents_one_hot = one_hot(parents, self.domain.vocab_size) + + output = self._mut_predictor(parents_one_hot) + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + + self.rng, subrng = jax.random.split(self.rng) + positions = sample_log_probs_top_k( + pos_log_probs, + subrng, + k=self.cfg.num_mutations, + temperature=self.cfg.temperature) + + pos_masks = one_hot(positions, length) + + mutation_logprobs = gather_positions(positions, all_mut_log_probs) + assert mutation_logprobs.shape == (output.shape[0], self.cfg.num_mutations, + output.shape[-1] - 1) + + self.rng, subrng = jax.random.split(self.rng) + mutation_types = gumbel_max_sampler( + mutation_logprobs, self.cfg.temperature, subrng) + + states = apply_mutations(parents, mutation_types, pos_masks, self.mutations, + use_assignment_mutations=self.use_assignment_mut) + # states shape: [num_mutations+1, batch, str_length] + # states[0] are original samples with no mutations + # states[-1] are strings with all mutations applied to them + states_oh = one_hot(states, self.domain.vocab_size) + # states_oh shape: [n_mutations+1, batch, str_length, vocab_size] + perturbed = states[-1] + + return perturbed, (mutation_types, positions), states_oh + + def _update_params(self, population): + if not self._data_for_grad_update: + return + + dat = self._data_for_grad_update.pop() + assert dat["batch_index"] == population.current_batch_index + + child_rewards = jnp.array(population.get_last_batch().rewards) + parent_rewards = dat["parent_rewards"] + + all_states = dat["mut_predictor_input"] + # all_states shape: [num_mutations, batch_size, str_length, vocab_size] + + # TODO(rubanova): rescale the rewards + terminal_rewards = child_rewards + + if self.cfg.use_actor_critic: + # Update the value function + # Compute the difference between predicted value of intermediate states + # and the final reward. + self._value_predictor.update_step( + rewards=terminal_rewards, + inputs=all_states[:-1], + actions=None, + ) + advantage = compute_advantage(self._value_predictor.params, + self._value_predictor.run_model, + terminal_rewards, all_states[:-1]) + else: + advantage = child_rewards - parent_rewards + advantage = jnp.repeat(advantage[:, None], self.cfg.num_mutations, 1) + + advantage = jax.lax.stop_gradient(advantage) + + # Perform policy update. + # Compute policy on the original samples, like in _perturb function. 
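+    # Plain REINFORCE step: the (stop-gradient) advantage -- critic-based if
+    # `use_actor_critic`, otherwise child reward minus parent reward -- weights
+    # the log-probability of the sampled (position, mutation) actions, which
+    # the predictor recomputes on the unmutated parent strings.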
+    self._mut_predictor.update_step(
+        rewards=advantage,
+        inputs=all_states[0],
+        actions=dat["actions"])
+
+    del all_states, advantage
+
+  @property
+  def trained_model(self):
+    return (self._mut_predictor, self._value_predictor)
\ No newline at end of file
diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py
new file mode 100644
index 0000000..d271de1
--- /dev/null
+++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py
@@ -0,0 +1,272 @@
+# coding=utf-8
+# Copyright 2024 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Specifications for different types of input/output domains."""
+
+import abc
+import collections
+from collections.abc import Iterable
+
+import gin
+import numpy as np
+import six
+from six.moves import range
+
+from poli_baselines.solvers.bayesian_optimization.amortized import utils
+
+BOS_TOKEN = '<'  # Beginning of sequence token.
+EOS_TOKEN = '>'  # End of sequence token.
+PAD_TOKEN = '_'  # Padding token (marks positions past the end of a sequence).
+MASK_TOKEN = '*'  # Mask token (marks masked-out positions, e.g. for BERT).
+SEP_TOKEN = '|'  # A special token for separating tokens for serialization.
+
+
+@gin.configurable
+class Vocabulary(object):
+  """Basic vocabulary used to represent output tokens for domains."""
+
+  def __init__(self,
+               tokens,
+               include_bos=False,
+               include_eos=False,
+               include_pad=False,
+               include_mask=False,
+               bos_token=BOS_TOKEN,
+               eos_token=EOS_TOKEN,
+               pad_token=PAD_TOKEN,
+               mask_token=MASK_TOKEN):
+    """A token vocabulary.
+
+    Args:
+      tokens: A list of tokens to put in the vocab. If an int, will be
+        interpreted as the number of tokens and '0', ..., 'tokens-1' will be
+        used as tokens.
+      include_bos: Whether to append `bos_token`, which marks the beginning
+        of a sequence, to `tokens`.
+      include_eos: Whether to append `eos_token`, which marks the end of a
+        sequence, to `tokens`.
+      include_pad: Whether to append `pad_token`, which marks positions past
+        the end of a sequence, to `tokens`.
+      include_mask: Whether to append `mask_token`, which marks masked
+        positions, to `tokens`.
+      bos_token: A special token that marks the beginning of a sequence.
+        Ignored if `include_bos == False`.
+      eos_token: A special token that marks the end of a sequence.
+        Ignored if `include_eos == False`.
+      pad_token: A special token that marks positions past the end of a
+        sequence. Ignored if `include_pad == False`.
+      mask_token: A special token that marks MASKED positions for e.g. BERT.
+        Ignored if `include_mask == False`.
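+
+    Example (illustrative; any small alphabet works):
+      vocab = Vocabulary(['A', 'C', 'G', 'T'])
+      vocab.encode('ACGT')  # -> [0, 1, 2, 3]
+      vocab.decode([3, 0])  # -> 'TA'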
+ """ + if not isinstance(tokens, Iterable): + tokens = range(tokens) + tokens = [str(token) for token in tokens] + if include_bos: + tokens.append(bos_token) + if include_eos: + tokens.append(eos_token) + if include_pad: + tokens.append(pad_token) + if include_mask: + tokens.append(mask_token) + if len(set(tokens)) != len(tokens): + raise ValueError('tokens not unique!') + special_tokens = sorted(set(tokens) & set([SEP_TOKEN])) + if special_tokens: + raise ValueError( + f'tokens contains reserved special tokens: {special_tokens}!') + + self._tokens = tokens + self._token_ids = list(range(len(self._tokens))) + self._id_to_token = collections.OrderedDict( + zip(self._token_ids, self._tokens)) + self._token_to_id = collections.OrderedDict( + zip(self._tokens, self._token_ids)) + self._bos_token = bos_token if include_bos else None + self._eos_token = eos_token if include_eos else None + self._mask_token = mask_token if include_mask else None + self._pad_token = pad_token if include_pad else None + + def __len__(self): + return len(self._tokens) + + @property + def tokens(self): + """Return the tokens of the vocabulary.""" + return list(self._tokens) + + @property + def token_ids(self): + """Return the tokens ids of the vocabulary.""" + return list(self._token_ids) + + @property + def bos(self): + """Returns the index of the BOS token or None if unspecified.""" + return (None if self._bos_token is None else + self._token_to_id[self._bos_token]) + + @property + def eos(self): + """Returns the index of the EOS token or None if unspecified.""" + return (None if self._eos_token is None else + self._token_to_id[self._eos_token]) + + @property + def mask(self): + """Returns the index of the MASK token or None if unspecified.""" + return (None if self._mask_token is None else + self._token_to_id[self._mask_token]) + + @property + def pad(self): + """Returns the index of the PAD token or None if unspecified.""" + return (None + if self._pad_token is None else self._token_to_id[self._pad_token]) + + def is_valid(self, value): + """Tests if a value is a valid token id and returns a bool.""" + return value in self._token_ids + + def are_valid(self, values): + """Tests if values are valid token ids and returns an array of bools.""" + return np.array([self.is_valid(value) for value in values]) + + def encode(self, tokens): + """Maps an iterable of string tokens to a list of integer token ids.""" + if six.PY3 and isinstance(tokens, bytes): + # Always use Unicode in Python 3. + tokens = tokens.decode('utf-8') + return [self._token_to_id[token] for token in tokens] + + def decode(self, values, stop_at_eos=False, as_str=True): + """Maps an iterable of integer token ids to string tokens. + + Args: + values: An iterable of token ids. + stop_at_eos: Whether to ignore all values after the first EOS token id. + as_str: Whether to return a list of tokens or a concatenated string. + + Returns: + A string of tokens or a list of tokens if `as_str == False`. + """ + if stop_at_eos and self.eos is None: + raise ValueError('EOS unspecified!') + tokens = [] + for value in values: + value = int(value) # Requires if value is a scalar tensor. 
+ if stop_at_eos and value == self.eos: + break + tokens.append(self._id_to_token[value]) + return ''.join(tokens) if as_str else tokens + + +@six.add_metaclass(abc.ABCMeta) +class Domain(object): + """Base class of problem domains, which specifies the set of valid objects.""" + + @property + def mask_fn(self): + """Returns a masking function or None.""" + + @abc.abstractmethod + def is_valid(self, sample): + """Tests if the given sample is valid for this domain.""" + + def are_valid(self, samples): + """Tests if the given samples are valid for this domain.""" + return np.array([self.is_valid(sample) for sample in samples]) + + +class DiscreteDomain(Domain): + """Base class for discrete domains: sequences of categorical variables.""" + + def __init__(self, vocab): + self._vocab = vocab + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._vocab # pytype: disable=attribute-error # trace-all-classes + + def encode(self, samples, **kwargs): + """Maps a list of string tokens to a list of lists of integer token ids.""" + return [self.vocab.encode(sample, **kwargs) for sample in samples] + + def decode(self, samples, **kwargs): + """Maps list of lists of integer token ids to list of strings.""" + return [self.vocab.decode(sample, **kwargs) for sample in samples] + + +@gin.configurable +class FixedLengthDiscreteDomain(DiscreteDomain): + """Output is a fixed length discrete sequence.""" + + def __init__(self, vocab_size=None, length=None, vocab=None): + """Creates an instance of this class. + + Args: + vocab_size: An optional integer for constructing a vocab of this size. + If provided, `vocab` must be `None`. + length: The length of the domain (required). + vocab: The `Vocabulary` of the domain. If provided, `vocab_size` must be + `None`. + + Raises: + ValueError: If neither `vocab_size` nor `vocab` is provided. + ValueError: If `length` if not provided. 
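+
+    Example (illustrative):
+      domain = FixedLengthDiscreteDomain(vocab_size=4, length=3)
+      domain.size                 # -> 64 possible structures
+      domain.is_valid([0, 3, 2])  # -> True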
+ """ + if length is None: + raise ValueError('length must be provided!') + if not (vocab_size is None) ^ (vocab is None): + raise ValueError('Exactly one of vocab_size of vocab must be specified!') + self._length = length + if vocab is None: + vocab = Vocabulary(vocab_size) + super(FixedLengthDiscreteDomain, self).__init__(vocab) + + @property + def length(self): + return self._length + + @property + def size(self): + """The number of structures in the Domain.""" + return self.vocab_size**self.length + + def is_valid(self, sequence): + return len(sequence) == self.length and self.vocab.are_valid(sequence).all() + + def sample_uniformly(self, num_samples, seed=None): + random_state = utils.get_random_state(seed) + return np.int32( + random_state.randint( + size=[num_samples, self.length], low=0, high=self.vocab_size)) + + def index_to_structure(self, index): + """Given an integer and target length, encode into structure.""" + structure = np.zeros(self.length, dtype=np.int32) + tokens = [int(token, base=len(self.vocab)) + for token in np.base_repr(index, base=len(self.vocab))] + structure[-len(tokens):] = tokens + return structure + + def structure_to_index(self, structure): + """Returns the index of a sequence over a vocabulary of size `vocab_size`.""" + structure = np.asarray(structure)[::-1] + return np.sum(structure * np.power(len(self.vocab), range(len(structure)))) \ No newline at end of file diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt b/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py new file mode 100644 index 0000000..22a29a7 --- /dev/null +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py @@ -0,0 +1,364 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions.""" + +import collections +import copy +import functools +import inspect +import logging as std_logging +import pprint +import random +import types +import uuid + +from absl import logging +import numpy as np +import pandas as pd +import tensorflow.compat.v1 as tf + +MIN_INT = np.iinfo(np.int64).min +MAX_INT = np.iinfo(np.int64).max +MIN_FLOAT = np.finfo(np.float32).min +MAX_FLOAT = np.finfo(np.float32).max + + +# TODO(ddohan): FrozenConfig type +class Config(dict): + """a dictionary that supports dot and dict notation. 
+ + Create: + d = Config() + d = Config({'val1':'first'}) + + Get: + d.val2 + d['val2'] + + Set: + d.val2 = 'second' + d['val2'] = 'second' + """ + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __str__(self): + return pprint.pformat(self) + + def __deepcopy__(self, memo): + return self.__class__([(copy.deepcopy(k, memo), copy.deepcopy(v, memo)) + for k, v in self.items()]) + + +def get_logger(name='', with_absl=True, level=logging.INFO): + """Creates a logger.""" + logger = std_logging.getLogger(name) + if with_absl: + logger.addHandler(logging.get_absl_handler()) + logger.propagate = False + logger.setLevel(level) + return logger + + +def get_random_state(seed_or_state): + """Returns a np.random.RandomState given an integer seed or RandomState.""" + if isinstance(seed_or_state, int): + return np.random.RandomState(seed_or_state) + elif seed_or_state is None: + # This returns the current global np random state. + return np.random.random.__self__ + elif not isinstance(seed_or_state, np.random.RandomState): + raise ValueError('Numpy RandomState or integer seed expected! Got: %s' % + seed_or_state) + else: + return seed_or_state + + +def set_seed(seed): + """Sets global Numpy, Tensorboard, and random seed.""" + np.random.seed(seed) + tf.set_random_seed(seed) + random.seed(seed, version=1) + + +def to_list(values, none_to_list=True): + """Converts `values` of any type to a `list`.""" + if (hasattr(values, '__iter__') and not isinstance(values, str) and + not isinstance(values, dict)): + return list(values) + elif none_to_list and values is None: + return [] + else: + return [values] + + +def to_array(values): + """Converts input values to a np.ndarray.""" + if tf.executing_eagerly() and tf.is_tensor(values): + return values.numpy() + else: + return np.asarray(values) + + +def arrays_from_dataset(dataset): + """Converts a tf.data.Dataset to nested np.ndarrays.""" + return tf.nest.map_structure( + lambda tensor: np.asarray(tensor), # pylint: disable=unnecessary-lambda + tensors_from_dataset(dataset)) + + +def dataset_from_tensors(tensors): + """Converts nested tf.Tensors or np.ndarrays to a tf.Data.Dataset.""" + if isinstance(tensors, types.GeneratorType) or isinstance(tensors, list): + tensors = tuple(tensors) + return tf.data.Dataset.from_tensor_slices(tensors) + + +def random_choice(values, size=None, random_state=None, **kwargs): + """Enables safer sampling from a list of values than `np.random.choice`. + + `np.random.choice` fails when trying to sample, e.g., from a list of + + indices instead of sampling from `values` directly. + + Args: + values: An iterable of values. + size: The sample size. + random_state: An integer seed or a `np.random.RandomState`. + **kwargs: Named arguments passed to `np.random.choice`. + + Returns: + As single element from `values` if `size is None`, otherwise a list of + samples from `values` of length `size`. 
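+
+  Example (illustrative):
+    pairs = [(0, 1), (2, 3), (4, 5)]
+    random_choice(pairs, size=2, random_state=0)  # samples whole tuples, not indices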
+ """ + random_state = get_random_state(random_state) + values = list(values) + effective_size = 1 if size is None else size + idxs = random_state.choice(range(len(values)), size=effective_size, **kwargs) + samples = [values[idx] for idx in idxs] + return samples[0] if size is None else samples + + +def random_shuffle(values, random_state=None): + """Shuffles a list of `values` out-of-place.""" + return random_choice( + values, size=len(values), replace=False, random_state=random_state) + + +def get_tokens(sequences, lower=False, upper=False): + """Returns a sorted list of all unique characters of a list of sequences. + + Args: + sequences: An iterable of string sequences. + lower: Whether to lower-case sequences before computing tokens. + upper: Whether to upper-case sequences before computing tokens. + + Returns: + A sorted list of all characters that appear in `sequences`. + """ + if lower and upper: + raise ValueError('lower and upper must not be specified at the same time!') + if lower: + sequences = [seq.lower() for seq in sequences] + if upper: + sequences = [seq.upper() for seq in sequences] + return sorted(set.union(*[set(seq) for seq in sequences])) + + +def tensors_from_dataset(dataset): + """Converts a tf.data.Dataset to nested tf.Tensors.""" + tensors = list(dataset) + if tensors: + return tf.nest.map_structure(lambda *tensors: tf.stack(tensors), *tensors) + # Return empty tensors if the dataset is empty. + shapes_dtypes = zip( + tf.nest.flatten(dataset.output_shapes), + tf.nest.flatten(dataset.output_types)) + tensors = [ + tf.zeros(shape=[0] + shape.as_list(), dtype=dtype) + for shape, dtype in shapes_dtypes + ] + return tf.nest.pack_sequence_as(dataset.output_shapes, tensors) + + +def hash_structure(structure): + """Hashes a structure (n-d numpy array) of either ints or floats to a string. + + Args: + structure: A structure of ints that is castable to np.int32 (examples + include an np.int32, np.int64, an int32 eager tf.Tensor, + a list of python ints, etc.) or a structure of floats that is castable + to np.float32. Here, we say that an array is castable if it can be + converted to the target type, perhaps with some loss of precision + (e.g. float64 to float32). See np.can_cast(..., casting='same_kind'). + + Returns: + A string hash for the structure. The hash will depend on the + high-level type (int vs. float), but not the precision of such a type + (int32 vs. int64). + """ + array = np.asarray(structure) + if np.can_cast(array, np.int32, 'same_kind'): + return np.int32(array).tostring() + elif np.can_cast(array, np.float32, 'same_kind'): + return np.float32(array).tostring() + raise ValueError('%s can not be safely cast to np.int32 or ' + 'np.float32' % str(structure)) + + +def create_unique_id(): + """Creates a unique hex ID.""" + return uuid.uuid1().hex + + +def deduplicate_samples(samples, select_best=False): + """De-duplicates Samples with identical structures. + + Args: + samples: An iterable of `data.Sample`s. + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. + + Returns: + A list of Samples. 
+ """ + + def _sort(to_sort): + return (sorted(to_sort, key=lambda sample: sample.reward, reverse=True) + if select_best else to_sort) + + return [_sort(group)[0] for group in group_samples(samples).values()] + + +def group_samples(samples, **kwargs): + """Groups `data.Sample`s with identical structures using `group_by_hash`.""" + return group_by_hash( + samples, + hash_fn=lambda sample: hash_structure(sample.structure), + **kwargs) + + +def group_by_hash(values, hash_fn=hash, store_index=False): + """Groups values by their hash value. + + Args: + values: An iterable of any objects. + hash_fn: A function that is called to compute the hash of values. + store_index: Whether to store the index of values or values in the returned + dict. + + Returns: + A `collections.OrderedDict` mapping hashes to values with that hash if + `store_index=False`, or value indices otherwise. The length of the map + corresponds to the number of unique hashes. + """ + groups = collections.OrderedDict() + for idx, value in enumerate(values): + to_store = idx if store_index else value + groups.setdefault(hash_fn(value), []).append(to_store) + return groups + + +def get_instance(instance_or_cls, **kwargs): + """Returns an instance given an instance or class reference. + + Enables passing both class references and class instances as (gin) configs. + + Args: + instance_or_cls: An instance of class or reference to a class. + **kwargs: Names arguments used for instantiation if `instance_or_cls` is a + class. + + Returns: + An instance of a class. + """ + if (inspect.isclass(instance_or_cls) or inspect.isfunction(instance_or_cls) or + isinstance(instance_or_cls, functools.partial)): + return instance_or_cls(**kwargs) + else: + return instance_or_cls + + +def pd_option_context(width=999, + max_colwidth=999, + max_rows=200, + float_format='{:.3g}', + **kwargs): + """Returns a Pandas context manager with changed default arguments.""" + return pd.option_context('display.width', width, + 'display.max_colwidth', max_colwidth, + 'display.max_rows', max_rows, + 'display.float_format', float_format.format, + **kwargs) + + +def log_pandas(df_or_series, logger=logging.info, **kwargs): + """Logs a `pd.DataFrame` or `pd.Series`.""" + with pd_option_context(**kwargs): + for row in df_or_series.to_string().splitlines(): + logger(row) + + +def get_indices(valid_indices, + selection, + map_negative=True, + validate=False, + exclude=False): + """Maps a `selection` to `valid_indices` for indexing an iterable. + + Supports selecting indices as by + - a scalar: it[i] + - a list of scalars: it[[i, j, k]] + - a (list of) negative scalars: it[[i, -j, -k]] + - slices: it[slice(-3, None)] + + Args: + valid_indices: An iterable of valid indices that can be selected. + selection: A scalar, list of scalars, or `slice` for selecting indices. + map_negative: Whether to interpret `-i` as `len(valid_indices) + i`. + validate: Whether to raise an `IndexError` if `selection` is not + contained in `valid_indices`. + exclude: Whether to return all `valid_indices` except the selected ones. + + Raises: + IndexError: If `validate == True` and `selection` contains an index that is + not contained in `valid_indices`. + + Returns: + A list of indices. + """ + def _raise_index_error(idx): + raise IndexError(f'Index {idx} invalid! 
Valid indices: {valid_indices}') + + if isinstance(selection, slice): + idxs = valid_indices[selection] + else: + idxs = [] + for idx in to_list(selection): + if map_negative and isinstance(idx, int) and idx < 0: + if abs(idx) <= len(valid_indices): + idxs.append(valid_indices[idx]) + elif validate: + _raise_index_error(idx) + elif idx in valid_indices: + idxs.append(idx) + elif validate: + _raise_index_error(idx) + if exclude: + idxs = [idx for idx in valid_indices if idx not in idxs] + return idxs \ No newline at end of file diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py new file mode 100644 index 0000000..091110c --- /dev/null +++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py @@ -0,0 +1,32 @@ +"""This module tests the TURBO solver.""" + +import warnings +import numpy as np + +warnings.filterwarnings("ignore") + + +def test_turbo_runs(): + from poli import objective_factory + from poli_baselines.solvers.bayesian_optimization.turbo.turbo_wrapper import ( + TurboWrapper, + ) + + problem = objective_factory.create( + name="toy_continuous_problem", + function_name="ackley_function_01", + n_dimensions=10, + ) + black_box, x0 = problem.black_box, problem.x0 + y0 = black_box(x0) + + x0 = np.random.uniform(0, 1, size=20).reshape(2, 10) + y0 = black_box(x0) + + solver = TurboWrapper(black_box, x0, y0) + + solver.solve(max_iter=5) + + +if __name__ == "__main__": + test_turbo_runs() From 30902f1db39afb529931627044b63dbeebcb2ddf Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Mon, 27 May 2024 17:06:34 +0200 Subject: [PATCH 2/9] adds requirements --- .../bayesian_optimization/amortized/requirements.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt b/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt index e69de29..d8b5764 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/requirements.txt @@ -0,0 +1,6 @@ +tensorflow +gin-config >= 0.3.0 +pandas >= 1.0.5 +attrs >= 19.3.0 +jax >= 0.1.71 +jaxlib >= 0.1.48 \ No newline at end of file From 3ee1a86a4dc5904fdc6c719d2f5b63f0619fc496 Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Mon, 27 May 2024 17:06:54 +0200 Subject: [PATCH 3/9] adds a test --- .../bayesian_optimization/test_amortized_bo.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py index 091110c..1a1ddbf 100644 --- a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py +++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py @@ -3,30 +3,24 @@ import warnings import numpy as np +from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import AmortizedBOWrapper + warnings.filterwarnings("ignore") -def test_turbo_runs(): +def test_amortized_bo_runs(): from poli import objective_factory - from poli_baselines.solvers.bayesian_optimization.turbo.turbo_wrapper import ( - TurboWrapper, - ) problem = objective_factory.create( - name="toy_continuous_problem", - function_name="ackley_function_01", - n_dimensions=10, + name="aloha", observer_name=None ) black_box, x0 = problem.black_box, 
problem.x0
     y0 = black_box(x0)
 
-    x0 = np.random.uniform(0, 1, size=20).reshape(2, 10)
-    y0 = black_box(x0)
-
-    solver = TurboWrapper(black_box, x0, y0)
+    solver = AmortizedBOWrapper(black_box, x0, y0)
 
     solver.solve(max_iter=5)
 
 
 if __name__ == "__main__":
-    test_turbo_runs()
+    test_amortized_bo_runs()

From fb1c6d5008063b18326d80f6697e6496711e7f63 Mon Sep 17 00:00:00 2001
From: Simon Bartels
Date: Mon, 27 May 2024 17:07:30 +0200
Subject: [PATCH 4/9] removes imports that clash when running amortized BO
 from a different environment

---
 src/poli_baselines/solvers/__init__.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/src/poli_baselines/solvers/__init__.py b/src/poli_baselines/solvers/__init__.py
index 2ba2660..e69de29 100644
--- a/src/poli_baselines/solvers/__init__.py
+++ b/src/poli_baselines/solvers/__init__.py
@@ -1,17 +0,0 @@
-from .simple.random_mutation import RandomMutation
-from .simple.continuous_random_mutation import ContinuousRandomMutation
-from .simple.genetic_algorithm import FixedLengthGeneticAlgorithm
-
-from .bayesian_optimization.vanilla_bayesian_optimization import (
-    VanillaBayesianOptimization,
-)
-from .bayesian_optimization.line_bayesian_optimization import LineBO
-from .bayesian_optimization.saas_bayesian_optimization import SAASBO
-from .bayesian_optimization.baxus import BAxUS
-
-from .bayesian_optimization.latent_space_bayesian_optimization import (
-    LatentSpaceBayesianOptimization,
-)
-
-from .evolutionary_strategies.cma_es import CMA_ES
-from .multi_objective.nsga_ii import DiscreteNSGAII

From 9101befcbe2fa417f5460583eee1b0621787a32f Mon Sep 17 00:00:00 2001
From: Simon Bartels
Date: Tue, 28 May 2024 14:47:22 +0200
Subject: [PATCH 5/9] fixes problem with [0, 1] bounds

---
 .../turbo/turbo_wrapper.py                    | 31 ++++++++++++++++---
 .../bayesian_optimization/test_turbo.py       |  6 ++--
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
index 34a46ac..fcc40a0 100644
--- a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
+++ b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py
@@ -1,6 +1,7 @@
 # Code taken from https://botorch.org/tutorials/turbo_1
 import numpy as np
 from poli.core.abstract_black_box import AbstractBlackBox
+from typing import List
 
 from poli_baselines.core.abstract_solver import AbstractSolver
 import math
@@ -29,9 +30,31 @@
 
 
 class TurboWrapper(AbstractSolver):
-    def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray):
+    def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray, bounds: np.ndarray):
+        """
+
+        Parameters
+        ----------
+        black_box
+        x0
+        y0
+        bounds:
+            array of shape Dx2 where D is the dimensionality.
+            The first column contains the lower bounds on x, the second
+            column contains the upper bounds.
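+            Example (illustrative): np.array([[-1.0, 1.0]] * x0.shape[1])
+            constrains every dimension to the box [-1, 1].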
+ """ super().__init__(black_box, x0, y0) - self.X_turbo = torch.tensor(x0) + assert(x0.shape[0] > 1) + + assert(bounds.shape[1] == 2) + assert(bounds.shape[0] == x0.shape[1]) + assert(np.all(bounds[:, 1] >= bounds[:, 0])) + bounds[:, 1] -= bounds[:, 0] + def make_transforms(): + to_turbo = lambda X: (X - bounds[:, 0]) / bounds[:, 1] + from_turbo = lambda X: X * bounds[:, 1] + bounds[:, 0] + return to_turbo, from_turbo + self.to_turbo, self.from_turbo = make_transforms() + self.X_turbo = torch.tensor(self.to_turbo(x0)) self.Y_turbo = torch.tensor(y0) self.batch_size = 1 dim = x0.shape[1] @@ -70,14 +93,14 @@ def next_candidate(self) -> np.ndarray: raw_samples=RAW_SAMPLES, acqf="ts", ) - return X_next + return self.from_turbo(X_next.numpy()) def post_update(self, x: np.ndarray, y: np.ndarray) -> None: """ This method is called after the history is updated. """ Y_next = torch.tensor(y) - X_next = torch.tensor(x) + X_next = torch.tensor(self.to_turbo(x)) # Update state self.state = update_state(state=self.state, Y_next=Y_next) diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py index 091110c..c422c84 100644 --- a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py +++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py @@ -18,12 +18,12 @@ def test_turbo_runs(): n_dimensions=10, ) black_box, x0 = problem.black_box, problem.x0 + x0 = np.concatenate([x0, np.random.rand(1, x0.shape[1])]) y0 = black_box(x0) - x0 = np.random.uniform(0, 1, size=20).reshape(2, 10) - y0 = black_box(x0) + bounds = np.concatenate([-np.ones([x0.shape[1], 1]), np.ones([x0.shape[1], 1])], axis=-1) - solver = TurboWrapper(black_box, x0, y0) + solver = TurboWrapper(black_box, x0, y0, bounds=bounds) solver.solve(max_iter=5) From 899999fe12729554c5ef1bd6a8e9bd7f0f57840b Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Tue, 9 Jul 2024 10:20:16 +0200 Subject: [PATCH 6/9] linted offending files --- .../amortized/amortized_bo_wrapper.py | 28 +- .../amortized/base_solver.py | 121 +- .../bayesian_optimization/amortized/data.py | 1177 ++++++------ .../amortized/deep_evolution_solver.py | 1575 +++++++++-------- .../amortized/domains.py | 443 ++--- .../bayesian_optimization/amortized/utils.py | 529 +++--- .../turbo/turbo_wrapper.py | 10 +- .../test_amortized_bo.py | 8 +- .../bayesian_optimization/test_turbo.py | 4 +- 9 files changed, 1989 insertions(+), 1906 deletions(-) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py index 646688c..df3f539 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py @@ -2,18 +2,34 @@ from poli.core.abstract_black_box import AbstractBlackBox from poli_baselines.core.abstract_solver import AbstractSolver -from poli_baselines.solvers.bayesian_optimization.amortized.data import samples_from_arrays, Population -from poli_baselines.solvers.bayesian_optimization.amortized.deep_evolution_solver import MutationPredictorSolver -from poli_baselines.solvers.bayesian_optimization.amortized.domains import FixedLengthDiscreteDomain, Vocabulary +from poli_baselines.solvers.bayesian_optimization.amortized.data import ( + samples_from_arrays, + Population, +) +from 
poli_baselines.solvers.bayesian_optimization.amortized.deep_evolution_solver import ( + MutationPredictorSolver, +) +from poli_baselines.solvers.bayesian_optimization.amortized.domains import ( + FixedLengthDiscreteDomain, + Vocabulary, +) class AmortizedBOWrapper(AbstractSolver): def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray): super().__init__(black_box, x0, y0) - self.domain = FixedLengthDiscreteDomain(vocab=Vocabulary(black_box.get_black_box_info().get_alphabet()), length=x0.shape[1]) - self.solver = MutationPredictorSolver(domain=self.domain, initialize_dataset_fn=lambda *args, **kwargs: self.domain.encode(x0)) + self.domain = FixedLengthDiscreteDomain( + vocab=Vocabulary(black_box.get_black_box_info().get_alphabet()), + length=x0.shape[1], + ) + self.solver = MutationPredictorSolver( + domain=self.domain, + initialize_dataset_fn=lambda *args, **kwargs: self.domain.encode(x0), + ) def next_candidate(self) -> np.ndarray: - samples = samples_from_arrays(structures=self.domain.encode(self.x0.tolist()), rewards=self.y0.tolist()) + samples = samples_from_arrays( + structures=self.domain.encode(self.x0.tolist()), rewards=self.y0.tolist() + ) x = self.solver.propose(num_samples=1, population=Population(samples)) return np.array([c for c in self.domain.decode(x)[0]])[np.newaxis, :] diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py index f065b35..8138056 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/base_solver.py @@ -23,67 +23,60 @@ class BaseSolver(abc.ABC): - """Solver base class.""" - - def __init__(self, - domain, - random_state=None, - name=None, - log_level=logging.INFO, - **kwargs): - """Creates an instance of this class. - - Args: - domain: An instance of a `Domain`. - random_state: An instance of or integer seed to build a - `np.random.RandomState`. - name: The name of the solver. If `None`, will use the class name. - log_level: The logging level of the solver-specific logger. - -2=ERROR, -1=WARN, 0=INFO, 1=DEBUG. - **kwargs: Named arguments stored in `self.cfg`. - """ - self._domain = domain - self._name = name or self.__class__.__name__ - self._random_state = utils.get_random_state(random_state) - self._log = utils.get_logger(self._name, level=log_level) - - cfg = utils.Config(self._config()) - cfg.update(kwargs) - self.cfg = cfg - - def _config(self): - return {} - - @property - def domain(self): - """Return the optimization domain.""" - return self._domain - - @property - def name(self): - """Returns the solver name.""" - return self._name - - def __str__(self): - return self._name - - @abc.abstractmethod - def propose(self, - num_samples, - population=None, - pending_samples=None, - counter=0): - """Proposes num_samples from `self.domain`. - - Args: - num_samples: The number of samples to return. - population: A `Population` of samples or None if the population is empty. - pending_samples: A list of structures without reward that were already - proposed. - counter: The number of times `propose` has been called with the same - `population`. Can be used by solver to avoid repeated computations on - the same `population`, e.g. updating a model. - - Returns: - `num_samples` structures from the domain. 
- """ \ No newline at end of file + """Solver base class.""" + + def __init__( + self, domain, random_state=None, name=None, log_level=logging.INFO, **kwargs + ): + """Creates an instance of this class. + + Args: + domain: An instance of a `Domain`. + random_state: An instance of or integer seed to build a + `np.random.RandomState`. + name: The name of the solver. If `None`, will use the class name. + log_level: The logging level of the solver-specific logger. + -2=ERROR, -1=WARN, 0=INFO, 1=DEBUG. + **kwargs: Named arguments stored in `self.cfg`. + """ + self._domain = domain + self._name = name or self.__class__.__name__ + self._random_state = utils.get_random_state(random_state) + self._log = utils.get_logger(self._name, level=log_level) + + cfg = utils.Config(self._config()) + cfg.update(kwargs) + self.cfg = cfg + + def _config(self): + return {} + + @property + def domain(self): + """Return the optimization domain.""" + return self._domain + + @property + def name(self): + """Returns the solver name.""" + return self._name + + def __str__(self): + return self._name + + @abc.abstractmethod + def propose(self, num_samples, population=None, pending_samples=None, counter=0): + """Proposes num_samples from `self.domain`. + + Args: + num_samples: The number of samples to return. + population: A `Population` of samples or None if the population is empty. + pending_samples: A list of structures without reward that were already + proposed. + counter: The number of times `propose` has been called with the same + `population`. Can be used by solver to avoid repeated computations on + the same `population`, e.g. updating a model. + + Returns: + `num_samples` structures from the domain. + """ diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py index f98d72f..2b18658 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/data.py @@ -30,469 +30,262 @@ # DatasetSample defines the structure of samples in tf.data.Datasets for # pre-training solvers, whereas Samples (see below) defines the structure of # samples in a Population object. 
-DatasetSample = collections.namedtuple('DatasetSample', ['structure', 'reward']) +DatasetSample = collections.namedtuple("DatasetSample", ["structure", "reward"]) def dataset_to_population(population_or_tf_dataset): - """Converts a TF dataset to a Population if it is not already a Population.""" - if isinstance(population_or_tf_dataset, Population): - return population_or_tf_dataset - else: - return Population.from_dataset(population_or_tf_dataset) + """Converts a TF dataset to a Population if it is not already a Population.""" + if isinstance(population_or_tf_dataset, Population): + return population_or_tf_dataset + else: + return Population.from_dataset(population_or_tf_dataset) def serialize_structure(structure): - """Converts a structure to a string.""" - structure = np.asarray(structure) - dim = len(structure.shape) - if dim != 1: - raise NotImplementedError(f'`structure` must be 1d but is {dim}d!') - return domains.SEP_TOKEN.join(str(token) for token in structure) + """Converts a structure to a string.""" + structure = np.asarray(structure) + dim = len(structure.shape) + if dim != 1: + raise NotImplementedError(f"`structure` must be 1d but is {dim}d!") + return domains.SEP_TOKEN.join(str(token) for token in structure) def serialize_structures(structures, **kwargs): - """Converts a list of structures to a list of strings.""" - return [serialize_structure(structure, **kwargs) - for structure in structures] + """Converts a list of structures to a list of strings.""" + return [serialize_structure(structure, **kwargs) for structure in structures] def deserialize_structure(serialized_structure, dtype=np.int32): - """Converts a string to a structure. + """Converts a string to a structure. - Args: - serialized_structure: A structure produced by `serialize_structure`. - dtype: The data type of the output numpy array. + Args: + serialized_structure: A structure produced by `serialize_structure`. + dtype: The data type of the output numpy array. - Returns: - A numpy array with `dtype`. - """ - return np.asarray( - [token for token in serialized_structure.split(domains.SEP_TOKEN)], - dtype=dtype) + Returns: + A numpy array with `dtype`. + """ + return np.asarray( + [token for token in serialized_structure.split(domains.SEP_TOKEN)], dtype=dtype + ) def deserialize_structures(structures, **kwargs): - """Converts a list of strings to a list of structures. + """Converts a list of strings to a list of structures. - Args: - structures: A list of strings produced by `serialize_structures`. - **kwargs: Named arguments passed to `deserialize_structure`. + Args: + structures: A list of strings produced by `serialize_structures`. + **kwargs: Named arguments passed to `deserialize_structure`. - Returns: - A list of numpy array. - """ - return [deserialize_structure(structure, **kwargs) - for structure in structures] + Returns: + A list of numpy array. + """ + return [deserialize_structure(structure, **kwargs) for structure in structures] def serialize_population_frame(frame, inplace=False, domain=None): - """Serializes a population `pd.DataFrame` for representing it as plain text. - - Args: - frame: A `pd.DataFrame` produced by `Population.to_frame`. - inplace: Whether to serialize `frame` inplace instead of creating a copy. - domain: An optional domain for decoding structures. If provided, will - add a column `decoded_structure` with the serialized decoded structures. - - Returns: - A `pd.DataFrame` with serialized structures. 
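+
+    Example (illustrative):
+        frame = serialize_population_frame(population.to_frame())
+        # each structure is now a single '|'-separated string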
- """ - if not inplace: - frame = frame.copy() - if domain: - frame['decoded_structure'] = serialize_structures( - domain.decode(frame['structure'], as_str=False)) - frame['structure'] = serialize_structures(frame['structure']) - return frame - - -def deserialize_population_frame(frame, inplace=False): - """Deserializes a population `pd.DataFrame` from plain text. - - Args: - frame: A `pd.DataFrame` produced by `serialize_population_frame`. - inplace: Whether to deserialize `frame` inplace instead of creating a copy. - - Returns: - A `pd.DataFrame` with deserialized structures. - """ - if not inplace: - frame = frame.copy() - frame['structure'] = deserialize_structures(frame['structure']) - if 'decoded_structure' in frame.columns: - frame['decoded_structure'] = deserialize_structures( - frame['decoded_structure'], dtype=str) - return frame - - -def population_frame_to_csv(frame, path_or_buf=None, domain=None, index=False, - **kwargs): - """Converts a population `pd.DataFrame` to a csv table. - - Args: - frame: A `pd.DataFrame` produced by `Population.to_frame`. - path_or_buf: File path or object. If `None`, the result is returned as a - string. Otherwise write the csv table to that file. - domain: A optional domain for decoding structures. - index: Whether to store the index of `frame`. - **kwargs: Named arguments passed to `frame.to_csv`. - - Returns: - If `path_or_buf` is `None`, returns the resulting csv format as a - string. Otherwise returns `None`. - """ - if frame.empty: - raise ValueError('Cannot write empty population frame to CSV file!') - frame = serialize_population_frame(frame, domain=domain) - return frame.to_csv(path_or_buf, index=index, **kwargs) - - -def population_frame_from_csv(path_or_buf, **kwargs): - """Reads a population `pd.DataFrame` from a file. - - Args: - path_or_buf: A string path of file buffer. - **kwargs: Named arguments passed to `pd.read_csv`. - - Returns: - A `pd.DataFrame`. - """ - frame = pd.read_csv(path_or_buf, dtype={'metadata': object}, **kwargs) - frame = deserialize_population_frame(frame) - return frame - - -def subtract_mean_batch_reward(population): - """Returns new Population where each batch has mean-zero rewards.""" - df = population.to_frame() - mean_dict = df.groupby('batch_index').reward.mean().to_dict() - - def reward_for_sample(sample): - return sample.reward - mean_dict[sample.batch_index] - - shifted_samples = [ - sample.copy(reward=reward_for_sample(sample)) for sample in population - ] - return Population(shifted_samples) - - -def _to_immutable_array(array): - to_return = np.array(array) - to_return.setflags(write=False) - return to_return + """Serializes a population `pd.DataFrame` for representing it as plain text. + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + inplace: Whether to serialize `frame` inplace instead of creating a copy. + domain: An optional domain for decoding structures. If provided, will + add a column `decoded_structure` with the serialized decoded structures. -class _NumericConverter(object): - """Helper class for converting values to a numeric data type.""" - - def __init__(self, dtype, min_value=None, max_value=None): - self._dtype = dtype - self._min_value = min_value - self._max_value = max_value - - def __call__(self, value): - """Validates and converts `value` to `self._dtype`.""" - if value is None: - return value - if not np.isreal(value): - raise TypeError('%s is not numeric!' 
% value) - value = self._dtype(value) - if self._min_value is not None and value < self._min_value: - raise TypeError('%f < %f' % (value, self._min_value)) - if self._max_value is not None and value > self._max_value: - raise TypeError('%f > %f' % (value, self._max_value)) - return value + Returns: + A `pd.DataFrame` with serialized structures. + """ + if not inplace: + frame = frame.copy() + if domain: + frame["decoded_structure"] = serialize_structures( + domain.decode(frame["structure"], as_str=False) + ) + frame["structure"] = serialize_structures(frame["structure"]) + return frame -@attr.s( - frozen=True, # Make it immutable. - slots=True, # Improve memory overhead. - eq=False, # Because we override __eq__. -) -class Sample(object): - """Immutable container for a structure, reward, and additional data. - - Attributes: - key: (str) A unique identifier of the sample. If not provided, will create - a unique identifier using `utils.create_unique_id`. - structure: (np.ndarray) The structure. - reward: (float, optional) The reward. - batch_index: (int, optional) The batch index within the population. - infeasible: Whether the sample was marked as infeasible by the evaluator. - metadata: (any type, optional) Additional meta-data. - """ - structure: np.ndarray = attr.ib(factory=np.array, - converter=_to_immutable_array) # pytype: disable=wrong-arg-types # attr-stubs - reward: float = attr.ib( - converter=_NumericConverter(float), - default=None) - batch_index: int = attr.ib( - converter=_NumericConverter(int, min_value=0), - default=None) - infeasible: bool = attr.ib(default=False, - validator=attr.validators.instance_of(bool)) - key: str = attr.ib( - factory=utils.create_unique_id, - validator=attr.validators.optional(attr.validators.instance_of(str))) - metadata: typing.Dict[str, typing.Any] = attr.ib(default=None) - - def __eq__(self, other): - """Compares samples irrespective of their key.""" - return (self.reward == other.reward and - self.batch_index == other.batch_index and - self.infeasible == other.infeasible and - self.metadata == other.metadata and - np.array_equal(self.structure, other.structure)) - - def equal(self, other): - """Compares samples including their key.""" - return self == other and self.key == other.key - - def to_tfexample(self): - """Converts a Sample to a tf.Example.""" - features = dict( - structure=tf.train.Feature( - int64_list=tf.train.Int64List(value=self.structure)), - reward=tf.train.Feature( - float_list=tf.train.FloatList(value=[self.reward])), - batch_index=tf.train.Feature( - int64_list=tf.train.Int64List(value=[self.batch_index]))) - return tf.train.Example(features=tf.train.Features(feature=features)) - - def to_dict(self, dict_factory=collections.OrderedDict): - """Returns a dictionary from field names to values. +def deserialize_population_frame(frame, inplace=False): + """Deserializes a population `pd.DataFrame` from plain text. Args: - dict_factory: A class that implements a dict factory method. + frame: A `pd.DataFrame` produced by `serialize_population_frame`. + inplace: Whether to deserialize `frame` inplace instead of creating a copy. Returns: - A dict of type `dict_factory` + A `pd.DataFrame` with deserialized structures. 
""" - return attr.asdict(self, dict_factory=dict_factory) - - def copy(self, new_key=True, **kwargs): - """Returns a copy of this Sample with values overridden by **kwargs.""" - if new_key: - kwargs['key'] = utils.create_unique_id() - return attr.evolve(self, **kwargs) - - -def samples_from_arrays(structures, rewards=None, batch_index=None, - metadata=None): - """Makes a generator of Samples from fields. - - Args: - structures: Iterable of structures (1-D np array or list). - rewards: Iterable of float rewards. If None, the corresponding Samples are - given each given a reward of None. - batch_index: Either an int, in which case all Samples created by this - function will be given this batch_index or an iterable of ints for each - corresponding structure. - metadata: Metadata to store in the Sample. - - Yields: - A generator of Samples - """ - structures = utils.to_array(structures) - - if metadata is None: - metadata = [None] * len(structures) - - if rewards is None: - rewards = [None] * len(structures) - else: - rewards = utils.to_array(rewards) - - if len(structures) != len(rewards): - raise ValueError( - 'Structures and rewards must be same length. Are %s and %s' % - (len(structures), len(rewards))) - if len(metadata) != len(rewards): - raise ValueError('Metadata and rewards must be same length. Are %s and %s' % - (len(metadata), len(rewards))) - - if batch_index is None: - batch_index = 0 - if isinstance(batch_index, int): - batch_index = [batch_index] * len(structures) - - for structure, reward, batch_index, meta in zip(structures, rewards, - batch_index, metadata): - yield Sample( - structure=structure, - reward=reward, - batch_index=batch_index, - metadata=meta) - + if not inplace: + frame = frame.copy() + frame["structure"] = deserialize_structures(frame["structure"]) + if "decoded_structure" in frame.columns: + frame["decoded_structure"] = deserialize_structures( + frame["decoded_structure"], dtype=str + ) + return frame -def parse_tf_example(example_proto): - """Converts tf.Example proto to dict of Tensors. - Args: - example_proto: A raw tf.Example proto. - Returns: - A dict of Tensors with fields structure, reward, and batch_index. - """ +def population_frame_to_csv( + frame, path_or_buf=None, domain=None, index=False, **kwargs +): + """Converts a population `pd.DataFrame` to a csv table. - feature_description = dict( - structure=tf.FixedLenSequenceFeature((), tf.int64, allow_missing=True), - reward=tf.FixedLenFeature([1], tf.float32), - batch_index=tf.FixedLenFeature([1], tf.int64)) - - return tf.io.parse_single_example( - serialized=example_proto, features=feature_description) + Args: + frame: A `pd.DataFrame` produced by `Population.to_frame`. + path_or_buf: File path or object. If `None`, the result is returned as a + string. Otherwise write the csv table to that file. + domain: A optional domain for decoding structures. + index: Whether to store the index of `frame`. + **kwargs: Named arguments passed to `frame.to_csv`. + Returns: + If `path_or_buf` is `None`, returns the resulting csv format as a + string. Otherwise returns `None`. + """ + if frame.empty: + raise ValueError("Cannot write empty population frame to CSV file!") + frame = serialize_population_frame(frame, domain=domain) + return frame.to_csv(path_or_buf, index=index, **kwargs) -class Population(object): - """Data structure for storing Samples.""" - def __init__(self, samples=None): - """Construct a Population. 
+def population_frame_from_csv(path_or_buf, **kwargs): + """Reads a population `pd.DataFrame` from a file. Args: - samples: An iterable of Samples + path_or_buf: A string path of file buffer. + **kwargs: Named arguments passed to `pd.read_csv`. + + Returns: + A `pd.DataFrame`. """ - self._samples = collections.OrderedDict() - self._batch_to_sample_keys = collections.defaultdict(list) + frame = pd.read_csv(path_or_buf, dtype={"metadata": object}, **kwargs) + frame = deserialize_population_frame(frame) + return frame - if samples is not None: - self.add_samples(samples) - def __str__(self): - if self.empty: - return '' - return ('' % - (len(self), dict(self.best().to_dict()))) +def subtract_mean_batch_reward(population): + """Returns new Population where each batch has mean-zero rewards.""" + df = population.to_frame() + mean_dict = df.groupby("batch_index").reward.mean().to_dict() - def __len__(self): - return len(self._samples) + def reward_for_sample(sample): + return sample.reward - mean_dict[sample.batch_index] - def __iter__(self): - return self._samples.values().__iter__() + shifted_samples = [ + sample.copy(reward=reward_for_sample(sample)) for sample in population + ] + return Population(shifted_samples) - def __eq__(self, other): - if not isinstance(other, Population): - raise ValueError('Cannot compare equality with an object of ' - 'type %s' % (str(type(other)))) - return len(self) == len(other) and all( - s1 == s2 for s1, s2 in zip(self, other)) +def _to_immutable_array(array): + to_return = np.array(array) + to_return.setflags(write=False) + return to_return - def __add__(self, other): - """Adds samples to this population and returns a new Population.""" - return Population(itertools.chain(self, other)) - def __getitem__(self, key_or_index): - if isinstance(key_or_index, str): - return self._samples[key_or_index] - else: - return list(self._samples.values())[key_or_index] - - def __contains__(self, key): - if isinstance(key, Sample): - key = key.key - return key in self._samples - - def copy(self): - """Copies the population.""" - return Population(self.samples) - - @property - def samples(self): - """Returns the population Samples as a list.""" - return list(self._samples.values()) - - def add_sample(self, sample): - """Add copy of sample to population.""" - if sample.key in self._samples: - raise ValueError('Sample with key %s already exists in the population!') - self._samples[sample.key] = sample - batch_idx = sample.batch_index - self._batch_to_sample_keys[batch_idx].append(sample.key) - - def add_samples(self, samples): - """Convenience method for adding multiple samples.""" - for sample in samples: - self.add_sample(sample) - - @property - def empty(self): - return not self - - @property - def batch_indices(self): - """Returns a sorted list of unique batch indices.""" - return sorted(self._batch_to_sample_keys) - - @property - def max_batch_index(self): - """Returns the maximum batch index.""" - if self.empty: - raise ValueError('Population empty!') - return self.batch_indices[-1] - - @property - def current_batch_index(self): - """Return the maximum batch index or -1 if the population is empty.""" - return -1 if self.empty else self.max_batch_index - - def get_batches(self, batch_indices, exclude=False, validate=True): - """"Extracts certain batches from the population. - - Ignores batches that do not exist in the population. To validate if a - certain batch exists use `batch_index in population.batch_indices`. 
+class _NumericConverter(object): + """Helper class for converting values to a numeric data type.""" + + def __init__(self, dtype, min_value=None, max_value=None): + self._dtype = dtype + self._min_value = min_value + self._max_value = max_value + + def __call__(self, value): + """Validates and converts `value` to `self._dtype`.""" + if value is None: + return value + if not np.isreal(value): + raise TypeError("%s is not numeric!" % value) + value = self._dtype(value) + if self._min_value is not None and value < self._min_value: + raise TypeError("%f < %f" % (value, self._min_value)) + if self._max_value is not None and value > self._max_value: + raise TypeError("%f > %f" % (value, self._max_value)) + return value - Args: - batch_indices: An integer, iterable of integers, or `slice` object - for selecting batches. - exclude: If true, will return all batches but the ones that are selected. - validate: Whether to raise an exception if a batch index is invalid - instead of ignoring. - Returns: - A `Population` with the selected batches. +@attr.s( + frozen=True, # Make it immutable. + slots=True, # Improve memory overhead. + eq=False, # Because we override __eq__. +) +class Sample(object): + """Immutable container for a structure, reward, and additional data. + + Attributes: + key: (str) A unique identifier of the sample. If not provided, will create + a unique identifier using `utils.create_unique_id`. + structure: (np.ndarray) The structure. + reward: (float, optional) The reward. + batch_index: (int, optional) The batch index within the population. + infeasible: Whether the sample was marked as infeasible by the evaluator. + metadata: (any type, optional) Additional meta-data. """ - batch_indices = utils.get_indices( - self.batch_indices, batch_indices, exclude=exclude, validate=validate) - sample_idxs = [] - for batch_index in batch_indices: - sample_idxs.extend(self._batch_to_sample_keys.get(batch_index, [])) - samples = [self[idx] for idx in sample_idxs] - return Population(samples) - - def get_batch(self, *args, **kwargs): - return self.get_batches(*args, **kwargs) - - def get_last_batch(self, **kwargs): - """Returns the last batch from the population.""" - return self.get_batch(-1, **kwargs) - - def get_last_batches(self, n=1, **kwargs): - """Selects the last n batches.""" - return self.get_batches(self.batch_indices[-n:], **kwargs) - - def to_structures_and_rewards(self): - """Return (list of structures, list of rewards) in the Population.""" - structures = [] - rewards = [] - for sample in self: - structures.append(sample.structure) - rewards.append(sample.reward) - return structures, rewards - - @property - def structures(self): - """Returns the structure of all samples in the population.""" - return [sample.structure for sample in self] - - @property - def rewards(self): - """Returns the reward of all samples in the population.""" - return [sample.reward for sample in self] - - @staticmethod - def from_arrays(structures, rewards=None, batch_index=0, metadata=None): - """Creates Population from the specified fields. 
+ + structure: np.ndarray = attr.ib( + factory=np.array, converter=_to_immutable_array + ) # pytype: disable=wrong-arg-types # attr-stubs + reward: float = attr.ib(converter=_NumericConverter(float), default=None) + batch_index: int = attr.ib( + converter=_NumericConverter(int, min_value=0), default=None + ) + infeasible: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) + key: str = attr.ib( + factory=utils.create_unique_id, + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + metadata: typing.Dict[str, typing.Any] = attr.ib(default=None) + + def __eq__(self, other): + """Compares samples irrespective of their key.""" + return ( + self.reward == other.reward + and self.batch_index == other.batch_index + and self.infeasible == other.infeasible + and self.metadata == other.metadata + and np.array_equal(self.structure, other.structure) + ) + + def equal(self, other): + """Compares samples including their key.""" + return self == other and self.key == other.key + + def to_tfexample(self): + """Converts a Sample to a tf.Example.""" + features = dict( + structure=tf.train.Feature( + int64_list=tf.train.Int64List(value=self.structure) + ), + reward=tf.train.Feature(float_list=tf.train.FloatList(value=[self.reward])), + batch_index=tf.train.Feature( + int64_list=tf.train.Int64List(value=[self.batch_index]) + ), + ) + return tf.train.Example(features=tf.train.Features(feature=features)) + + def to_dict(self, dict_factory=collections.OrderedDict): + """Returns a dictionary from field names to values. + + Args: + dict_factory: A class that implements a dict factory method. + + Returns: + A dict of type `dict_factory` + """ + return attr.asdict(self, dict_factory=dict_factory) + + def copy(self, new_key=True, **kwargs): + """Returns a copy of this Sample with values overridden by **kwargs.""" + if new_key: + kwargs["key"] = utils.create_unique_id() + return attr.evolve(self, **kwargs) + + +def samples_from_arrays(structures, rewards=None, batch_index=None, metadata=None): + """Makes a generator of Samples from fields. Args: structures: Iterable of structures (1-D np array or list). @@ -501,185 +294,413 @@ def from_arrays(structures, rewards=None, batch_index=0, metadata=None): batch_index: Either an int, in which case all Samples created by this function will be given this batch_index or an iterable of ints for each corresponding structure. - metadata: Metadata to store in the Samples. + metadata: Metadata to store in the Sample. - Returns: - A Population. - """ - samples = samples_from_arrays(structures, rewards, batch_index, metadata) - return Population(samples) - - def add_batch(self, structures, rewards=None, batch_index=None, - metadata=None): - """Adds a batch of samples to the Population. - - Args: - structures: Iterable of structures (1-D np array or list). - rewards: Iterable of rewards. - batch_index: Either an int, in which case all Samples created by this - function will be given this batch_index or an iterable of ints for each - corresponding structure. If `None`, uses `self.current_batch + 1`. - metadata: Metadata to store in the Samples. 
+ Yields: + A generator of Samples """ - if batch_index is None: - batch_index = self.current_batch_index + 1 - samples = samples_from_arrays(structures, rewards, batch_index, metadata) - self.add_samples(samples) - - def head(self, n): - """Returns new Population containing first n samples.""" - - return Population(self.samples[:n]) - - @staticmethod - def from_dataset(dataset): - """Converts dataset of DatasetSample to Population.""" - samples = utils.arrays_from_dataset(dataset) - return Population.from_arrays(samples.structure, samples.reward) - - def to_dataset(self): - """Converts the population to a `tf.data.Dataset` with `DatasetSample`s.""" - structures, rewards = self.to_structures_and_rewards() - return utils.dataset_from_tensors( - DatasetSample(structure=structures, reward=rewards)) - - @staticmethod - def from_tfrecord(filename): - """Reads Population from tfrecord file.""" - raw_dataset = tf.data.TFRecordDataset([filename]) - parsed_dataset = raw_dataset.map(parse_tf_example) - - def _record_to_dict(record): - mapping = {key: utils.to_array(value) for key, value in record.items()} - if 'batch_index' in mapping: - mapping['batch_index'] = int(mapping['batch_index']) - return mapping - - return Population(Sample(**_record_to_dict(r)) for r in parsed_dataset) - - def to_tfrecord(self, filename): - """Writes Population to tfrecord file.""" - - with tf.python_io.TFRecordWriter(filename) as writer: - for sample in self.samples: - writer.write(sample.to_tfexample().SerializeToString()) - - def to_frame(self): - """Converts a `Population` to a `pd.DataFrame`.""" - records = [sample.to_dict() for sample in self.samples] - return pd.DataFrame.from_records(records) - - @classmethod - def from_frame(cls, frame): - """Converts a `pd.DataFrame` to a `Population`.""" - if frame.empty: - return cls() - population = cls() - for _, row in frame.iterrows(): - sample = Sample(**row.to_dict()) - population.add_sample(sample) - return population + structures = utils.to_array(structures) - def to_csv(self, path, domain=None): - """Stores a population to a CSV file. + if metadata is None: + metadata = [None] * len(structures) - Args: - path: The output file path. - domain: An optional `domains.Domain`. If provided, will also store - decoded structures in the CSV file. - """ - population_frame_to_csv(self.to_frame(), path, domain=domain) + if rewards is None: + rewards = [None] * len(structures) + else: + rewards = utils.to_array(rewards) + + if len(structures) != len(rewards): + raise ValueError( + "Structures and rewards must be same length. Are %s and %s" + % (len(structures), len(rewards)) + ) + if len(metadata) != len(rewards): + raise ValueError( + "Metadata and rewards must be same length. Are %s and %s" + % (len(metadata), len(rewards)) + ) - @classmethod - def from_csv(cls, path): - """Restores a population from a CSV file. + if batch_index is None: + batch_index = 0 + if isinstance(batch_index, int): + batch_index = [batch_index] * len(structures) - Args: - path: The CSV file path. + for structure, reward, batch_index, meta in zip( + structures, rewards, batch_index, metadata + ): + yield Sample( + structure=structure, reward=reward, batch_index=batch_index, metadata=meta + ) - Returns: - An instance of a `Population`. - """ - return cls.from_frame(population_frame_from_csv(path).drop( - columns=['decoded_structure'], errors='ignore')) - - def best_n(self, n=1, q=None, discard_duplicates=False, blacklist=None): - """Returns the best n samples. - Note that ties are broken deterministically. 
+def parse_tf_example(example_proto): + """Converts tf.Example proto to dict of Tensors. Args: - n: Max number to return - q: A float in (0, 1) corresponding to the minimum quantile for selecting - samples. If provided, `n` is ignored and samples with a reward >= - this quantile are selected. - discard_duplicates: If True, when several samples have the same structure, - return only one of them (the selected one is unspecified). - blacklist: Iterable of structures that should be excluded. - + example_proto: A raw tf.Example proto. Returns: - Population containing the best n Samples, sorted in decreasing order of - reward (output[0] is the best). Returns less than n if there are fewer - than n Samples in the population. + A dict of Tensors with fields structure, reward, and batch_index. """ - if self.empty: - raise ValueError('Population empty.') - - samples = self.samples - if blacklist: - samples = self._filter(samples, blacklist) - - # are unique. - if discard_duplicates and len(samples) > 1: - samples = utils.deduplicate_samples(samples) - - samples = sorted(samples, key=lambda sample: sample.reward, reverse=True) - if q is not None: - q_value = np.quantile([sample.reward for sample in samples], q) - return Population( - [sample for sample in samples if sample.reward >= q_value]) - else: - return Population(samples[:n]) - - def best(self, blacklist=None): - return self.best_n(1, blacklist=blacklist).samples[0] - - def _filter(self, samples, blacklist): - blacklist = set(utils.hash_structure(structure) for structure in blacklist) - return Population( - sample for sample in samples - if utils.hash_structure(sample.structure) not in blacklist) - - def contains_structure(self, structure): - return self.contains_structures([structure])[0] - - # TODO(dbelanger): this has computational overhead because it hashes the - # entire population. If this function is being called often, the set of hashed - # structures should be built up incrementally as elements are added to the - # population. Right now, the function is rarely called. - # TODO(ddohan): Build this just in time: track what hasn't been hashed - # since last call, and add any new structures before checking contain. - def contains_structures(self, structures): - structures_in_population = set( - utils.hash_structure(sample.structure) for sample in self.samples) - return [ - utils.hash_structure(structure) in structures_in_population - for structure in structures - ] - def deduplicate(self, select_best=False): - """De-duplicates Samples with identical structures. + feature_description = dict( + structure=tf.FixedLenSequenceFeature((), tf.int64, allow_missing=True), + reward=tf.FixedLenFeature([1], tf.float32), + batch_index=tf.FixedLenFeature([1], tf.int64), + ) - Args: - select_best: Whether to select the sample with the highest reward among - samples with the same structure. Otherwise, the sample that occurs - first will be selected. + return tf.io.parse_single_example( + serialized=example_proto, features=feature_description + ) - Returns: - A Population with de-duplicated samples. - """ - return Population(utils.deduplicate_samples(self, select_best=select_best)) - def discard_infeasible(self): - """Returns a new `Population` with all infeasible samples removed.""" - return Population(sample for sample in self if not sample.infeasible) \ No newline at end of file +class Population(object): + """Data structure for storing Samples.""" + + def __init__(self, samples=None): + """Construct a Population. 
+ + Args: + samples: An iterable of Samples + """ + self._samples = collections.OrderedDict() + self._batch_to_sample_keys = collections.defaultdict(list) + + if samples is not None: + self.add_samples(samples) + + def __str__(self): + if self.empty: + return "" + return "" % ( + len(self), + dict(self.best().to_dict()), + ) + + def __len__(self): + return len(self._samples) + + def __iter__(self): + return self._samples.values().__iter__() + + def __eq__(self, other): + if not isinstance(other, Population): + raise ValueError( + "Cannot compare equality with an object of " + "type %s" % (str(type(other))) + ) + + return len(self) == len(other) and all(s1 == s2 for s1, s2 in zip(self, other)) + + def __add__(self, other): + """Adds samples to this population and returns a new Population.""" + return Population(itertools.chain(self, other)) + + def __getitem__(self, key_or_index): + if isinstance(key_or_index, str): + return self._samples[key_or_index] + else: + return list(self._samples.values())[key_or_index] + + def __contains__(self, key): + if isinstance(key, Sample): + key = key.key + return key in self._samples + + def copy(self): + """Copies the population.""" + return Population(self.samples) + + @property + def samples(self): + """Returns the population Samples as a list.""" + return list(self._samples.values()) + + def add_sample(self, sample): + """Add copy of sample to population.""" + if sample.key in self._samples: + raise ValueError("Sample with key %s already exists in the population!") + self._samples[sample.key] = sample + batch_idx = sample.batch_index + self._batch_to_sample_keys[batch_idx].append(sample.key) + + def add_samples(self, samples): + """Convenience method for adding multiple samples.""" + for sample in samples: + self.add_sample(sample) + + @property + def empty(self): + return not self + + @property + def batch_indices(self): + """Returns a sorted list of unique batch indices.""" + return sorted(self._batch_to_sample_keys) + + @property + def max_batch_index(self): + """Returns the maximum batch index.""" + if self.empty: + raise ValueError("Population empty!") + return self.batch_indices[-1] + + @property + def current_batch_index(self): + """Return the maximum batch index or -1 if the population is empty.""" + return -1 if self.empty else self.max_batch_index + + def get_batches(self, batch_indices, exclude=False, validate=True): + """ "Extracts certain batches from the population. + + Ignores batches that do not exist in the population. To validate if a + certain batch exists use `batch_index in population.batch_indices`. + + Args: + batch_indices: An integer, iterable of integers, or `slice` object + for selecting batches. + exclude: If true, will return all batches but the ones that are selected. + validate: Whether to raise an exception if a batch index is invalid + instead of ignoring. + + Returns: + A `Population` with the selected batches. 
+ """ + batch_indices = utils.get_indices( + self.batch_indices, batch_indices, exclude=exclude, validate=validate + ) + sample_idxs = [] + for batch_index in batch_indices: + sample_idxs.extend(self._batch_to_sample_keys.get(batch_index, [])) + samples = [self[idx] for idx in sample_idxs] + return Population(samples) + + def get_batch(self, *args, **kwargs): + return self.get_batches(*args, **kwargs) + + def get_last_batch(self, **kwargs): + """Returns the last batch from the population.""" + return self.get_batch(-1, **kwargs) + + def get_last_batches(self, n=1, **kwargs): + """Selects the last n batches.""" + return self.get_batches(self.batch_indices[-n:], **kwargs) + + def to_structures_and_rewards(self): + """Return (list of structures, list of rewards) in the Population.""" + structures = [] + rewards = [] + for sample in self: + structures.append(sample.structure) + rewards.append(sample.reward) + return structures, rewards + + @property + def structures(self): + """Returns the structure of all samples in the population.""" + return [sample.structure for sample in self] + + @property + def rewards(self): + """Returns the reward of all samples in the population.""" + return [sample.reward for sample in self] + + @staticmethod + def from_arrays(structures, rewards=None, batch_index=0, metadata=None): + """Creates Population from the specified fields. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of float rewards. If None, the corresponding Samples are + given each given a reward of None. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. + metadata: Metadata to store in the Samples. + + Returns: + A Population. + """ + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + return Population(samples) + + def add_batch(self, structures, rewards=None, batch_index=None, metadata=None): + """Adds a batch of samples to the Population. + + Args: + structures: Iterable of structures (1-D np array or list). + rewards: Iterable of rewards. + batch_index: Either an int, in which case all Samples created by this + function will be given this batch_index or an iterable of ints for each + corresponding structure. If `None`, uses `self.current_batch + 1`. + metadata: Metadata to store in the Samples. 
+ """ + if batch_index is None: + batch_index = self.current_batch_index + 1 + samples = samples_from_arrays(structures, rewards, batch_index, metadata) + self.add_samples(samples) + + def head(self, n): + """Returns new Population containing first n samples.""" + + return Population(self.samples[:n]) + + @staticmethod + def from_dataset(dataset): + """Converts dataset of DatasetSample to Population.""" + samples = utils.arrays_from_dataset(dataset) + return Population.from_arrays(samples.structure, samples.reward) + + def to_dataset(self): + """Converts the population to a `tf.data.Dataset` with `DatasetSample`s.""" + structures, rewards = self.to_structures_and_rewards() + return utils.dataset_from_tensors( + DatasetSample(structure=structures, reward=rewards) + ) + + @staticmethod + def from_tfrecord(filename): + """Reads Population from tfrecord file.""" + raw_dataset = tf.data.TFRecordDataset([filename]) + parsed_dataset = raw_dataset.map(parse_tf_example) + + def _record_to_dict(record): + mapping = {key: utils.to_array(value) for key, value in record.items()} + if "batch_index" in mapping: + mapping["batch_index"] = int(mapping["batch_index"]) + return mapping + + return Population(Sample(**_record_to_dict(r)) for r in parsed_dataset) + + def to_tfrecord(self, filename): + """Writes Population to tfrecord file.""" + + with tf.python_io.TFRecordWriter(filename) as writer: + for sample in self.samples: + writer.write(sample.to_tfexample().SerializeToString()) + + def to_frame(self): + """Converts a `Population` to a `pd.DataFrame`.""" + records = [sample.to_dict() for sample in self.samples] + return pd.DataFrame.from_records(records) + + @classmethod + def from_frame(cls, frame): + """Converts a `pd.DataFrame` to a `Population`.""" + if frame.empty: + return cls() + population = cls() + for _, row in frame.iterrows(): + sample = Sample(**row.to_dict()) + population.add_sample(sample) + return population + + def to_csv(self, path, domain=None): + """Stores a population to a CSV file. + + Args: + path: The output file path. + domain: An optional `domains.Domain`. If provided, will also store + decoded structures in the CSV file. + """ + population_frame_to_csv(self.to_frame(), path, domain=domain) + + @classmethod + def from_csv(cls, path): + """Restores a population from a CSV file. + + Args: + path: The CSV file path. + + Returns: + An instance of a `Population`. + """ + return cls.from_frame( + population_frame_from_csv(path).drop( + columns=["decoded_structure"], errors="ignore" + ) + ) + + def best_n(self, n=1, q=None, discard_duplicates=False, blacklist=None): + """Returns the best n samples. + + Note that ties are broken deterministically. + + Args: + n: Max number to return + q: A float in (0, 1) corresponding to the minimum quantile for selecting + samples. If provided, `n` is ignored and samples with a reward >= + this quantile are selected. + discard_duplicates: If True, when several samples have the same structure, + return only one of them (the selected one is unspecified). + blacklist: Iterable of structures that should be excluded. + + Returns: + Population containing the best n Samples, sorted in decreasing order of + reward (output[0] is the best). Returns less than n if there are fewer + than n Samples in the population. + """ + if self.empty: + raise ValueError("Population empty.") + + samples = self.samples + if blacklist: + samples = self._filter(samples, blacklist) + + # are unique. 
+ if discard_duplicates and len(samples) > 1: + samples = utils.deduplicate_samples(samples) + + samples = sorted(samples, key=lambda sample: sample.reward, reverse=True) + if q is not None: + q_value = np.quantile([sample.reward for sample in samples], q) + return Population( + [sample for sample in samples if sample.reward >= q_value] + ) + else: + return Population(samples[:n]) + + def best(self, blacklist=None): + return self.best_n(1, blacklist=blacklist).samples[0] + + def _filter(self, samples, blacklist): + blacklist = set(utils.hash_structure(structure) for structure in blacklist) + return Population( + sample + for sample in samples + if utils.hash_structure(sample.structure) not in blacklist + ) + + def contains_structure(self, structure): + return self.contains_structures([structure])[0] + + # TODO(dbelanger): this has computational overhead because it hashes the + # entire population. If this function is being called often, the set of hashed + # structures should be built up incrementally as elements are added to the + # population. Right now, the function is rarely called. + # TODO(ddohan): Build this just in time: track what hasn't been hashed + # since last call, and add any new structures before checking contain. + def contains_structures(self, structures): + structures_in_population = set( + utils.hash_structure(sample.structure) for sample in self.samples + ) + return [ + utils.hash_structure(structure) in structures_in_population + for structure in structures + ] + + def deduplicate(self, select_best=False): + """De-duplicates Samples with identical structures. + + Args: + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. + + Returns: + A Population with de-duplicated samples. + """ + return Population(utils.deduplicate_samples(self, select_best=select_best)) + + def discard_infeasible(self): + """Returns a new `Population` with all infeasible samples removed.""" + return Population(sample for sample in self if not sample.infeasible) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py index cdf13de..dfc19fe 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py @@ -34,8 +34,8 @@ """ import functools -#from absl import logging -#import gin +# from absl import logging +# import gin import jax from jax.example_libraries import stax from jax.example_libraries.optimizers import adam @@ -50,929 +50,962 @@ def logsoftmax(x, axis=-1): - """Apply log softmax to an array of logits, log-normalizing along an axis.""" - return x - logsumexp(x, axis, keepdims=True) + """Apply log softmax to an array of logits, log-normalizing along an axis.""" + return x - logsumexp(x, axis, keepdims=True) def softmax(x, axis=-1): - return jnp.exp(logsoftmax(x, axis)) + return jnp.exp(logsoftmax(x, axis)) def one_hot(x, k): - """Create a one-hot encoding of x of size k.""" - return jnp.eye(k)[x] + """Create a one-hot encoding of x of size k.""" + return jnp.eye(k)[x] def gumbel_max_sampler(logits, temperature, rng): - """Sample fom categorical distribution using Gumbel-Max trick. + """Sample fom categorical distribution using Gumbel-Max trick. 
- Gumbel-Max trick: - https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ - https://arxiv.org/abs/1411.0030 + Gumbel-Max trick: + https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + https://arxiv.org/abs/1411.0030 - Args: - logits: Unnormalized logits for categorical distribution. - [batch x n_mutations_to_sample x n_mutation_types] - temperature: temperature parameter for Gumbel-Max. The lower the - temperature, the closer the sample is to one-hot-encoding. - rng: Jax random number generator + Args: + logits: Unnormalized logits for categorical distribution. + [batch x n_mutations_to_sample x n_mutation_types] + temperature: temperature parameter for Gumbel-Max. The lower the + temperature, the closer the sample is to one-hot-encoding. + rng: Jax random number generator - Returns: - class_assignments: Sampled class assignments [batch] - log_likelihoods: Log-likelihoods of the sampled mutations [batch] - """ + Returns: + class_assignments: Sampled class assignments [batch] + log_likelihoods: Log-likelihoods of the sampled mutations [batch] + """ - # Normalize the logits - logits = logsoftmax(logits) + # Normalize the logits + logits = logsoftmax(logits) - gumbel_noise = jrand.gumbel(rng, logits.shape) - softmax_logits = (logits + gumbel_noise) / temperature - soft_assignments = softmax(softmax_logits, -1) - class_assignments = jnp.argmax(soft_assignments, -1) - assert len(class_assignments.shape) == 2 - # Output shape: [batch x num_mutations] + gumbel_noise = jrand.gumbel(rng, logits.shape) + softmax_logits = (logits + gumbel_noise) / temperature + soft_assignments = softmax(softmax_logits, -1) + class_assignments = jnp.argmax(soft_assignments, -1) + assert len(class_assignments.shape) == 2 + # Output shape: [batch x num_mutations] - return class_assignments + return class_assignments ########################################## # Mutation-related helper functions def _mutate_position(structure, pos_mask, fn): - """Apply mutation fn to position specified by pos_mask.""" - structure = np.array(structure).copy() - pos_mask = np.array(pos_mask).astype(int) - structure[pos_mask == 1] = fn(structure[pos_mask == 1]) - return structure + """Apply mutation fn to position specified by pos_mask.""" + structure = np.array(structure).copy() + pos_mask = np.array(pos_mask).astype(int) + structure[pos_mask == 1] = fn(structure[pos_mask == 1]) + return structure def set_pos(x, pos_mask, val): - return _mutate_position(x, pos_mask, fn=lambda x: val) - - -def apply_mutations(samples, mutation_types, pos_masks, mutations, - use_assignment_mutations=False): - """Apply the mutations specified by mutation types to the batch of strings. - - Args: - samples: Batch of strings [batch x str_length] - mutation_types: IDs of mutation types to be applied to each string - [Batch x num_mutations] - pos_masks: One-hot encoding [Batch x num_mutations x str_length] - of the positions to be mutate in each string. - "num_mutations" positions will be mutated per string. - mutations: A list of possible mutation functions. - Functions should follow the format: fn(x, domain, pos_mask), - use_assignment_mutations: bool. Whether mutations are defined as - "Set position X to character C". If use_assignment_mutations=True, - then vectorize procedure of applying mutations to the string. - The index of mutation type should be equal to the index of the character. - Gives considerable speed-up to this function. - - Returns: - perturbed_samples: Strings perturbed according to the mutation list. 
- """ - batch_size = samples.shape[0] - assert len(mutation_types) == batch_size - assert len(pos_masks) == batch_size - - str_length = samples.shape[1] - assert pos_masks.shape[-1] == str_length - - # Check that number of mutations is consistent in mutation_types and positions - assert mutation_types.shape[1] == pos_masks.shape[1] - - num_mutations = mutation_types.shape[1] - - # List of batched samples with 0,1,2,... mutations - # First element of the list contains original samples - # Last element has samples with all mutations applied to the string - perturbed_samples_with_i_mutations = [samples] - for i in range(num_mutations): - - perturbed_samples = [] - samples_to_perturb = perturbed_samples_with_i_mutations[-1] - - if use_assignment_mutations: - perturbed_samples = samples_to_perturb.copy() - mask = pos_masks[:, i].astype(int) - # Assumes mutations are defined as "Set position to the character C" - perturbed_samples[np.array(mask) == 1] = mutation_types[:, i] - else: - for j in range(batch_size): - sample = samples_to_perturb[j].copy() + return _mutate_position(x, pos_mask, fn=lambda x: val) + + +def apply_mutations( + samples, mutation_types, pos_masks, mutations, use_assignment_mutations=False +): + """Apply the mutations specified by mutation types to the batch of strings. + + Args: + samples: Batch of strings [batch x str_length] + mutation_types: IDs of mutation types to be applied to each string + [Batch x num_mutations] + pos_masks: One-hot encoding [Batch x num_mutations x str_length] + of the positions to be mutate in each string. + "num_mutations" positions will be mutated per string. + mutations: A list of possible mutation functions. + Functions should follow the format: fn(x, domain, pos_mask), + use_assignment_mutations: bool. Whether mutations are defined as + "Set position X to character C". If use_assignment_mutations=True, + then vectorize procedure of applying mutations to the string. + The index of mutation type should be equal to the index of the character. + Gives considerable speed-up to this function. + + Returns: + perturbed_samples: Strings perturbed according to the mutation list. + """ + batch_size = samples.shape[0] + assert len(mutation_types) == batch_size + assert len(pos_masks) == batch_size + + str_length = samples.shape[1] + assert pos_masks.shape[-1] == str_length - pos = pos_masks[j, i] - mut_id = mutation_types[j, i] + # Check that number of mutations is consistent in mutation_types and positions + assert mutation_types.shape[1] == pos_masks.shape[1] - mutation = mutations[int(mut_id)] - perturbed_samples.append(mutation(sample, pos)) - perturbed_samples = np.stack(perturbed_samples) + num_mutations = mutation_types.shape[1] - assert perturbed_samples.shape == samples.shape - perturbed_samples_with_i_mutations.append(perturbed_samples) + # List of batched samples with 0,1,2,... 
mutations + # First element of the list contains original samples + # Last element has samples with all mutations applied to the string + perturbed_samples_with_i_mutations = [samples] + for i in range(num_mutations): - states = jnp.stack(perturbed_samples_with_i_mutations, 0) - assert states.shape == (num_mutations + 1,) + samples.shape - return states + perturbed_samples = [] + samples_to_perturb = perturbed_samples_with_i_mutations[-1] + + if use_assignment_mutations: + perturbed_samples = samples_to_perturb.copy() + mask = pos_masks[:, i].astype(int) + # Assumes mutations are defined as "Set position to the character C" + perturbed_samples[np.array(mask) == 1] = mutation_types[:, i] + else: + for j in range(batch_size): + sample = samples_to_perturb[j].copy() + + pos = pos_masks[j, i] + mut_id = mutation_types[j, i] + + mutation = mutations[int(mut_id)] + perturbed_samples.append(mutation(sample, pos)) + perturbed_samples = np.stack(perturbed_samples) + + assert perturbed_samples.shape == samples.shape + perturbed_samples_with_i_mutations.append(perturbed_samples) + + states = jnp.stack(perturbed_samples_with_i_mutations, 0) + assert states.shape == (num_mutations + 1,) + samples.shape + return states ########################################## # pylint: disable=invalid-name def OneHot(depth): - """Layer for transforming inputs to one-hot encoding.""" + """Layer for transforming inputs to one-hot encoding.""" - def init_fun(rng, input_shape): - del rng - return input_shape + (depth,), () + def init_fun(rng, input_shape): + del rng + return input_shape + (depth,), () - def apply_fun(params, inputs, **kwargs): - del params, kwargs - # Perform one-hot encoding - return jnp.eye(depth)[inputs.astype(int)] + def apply_fun(params, inputs, **kwargs): + del params, kwargs + # Perform one-hot encoding + return jnp.eye(depth)[inputs.astype(int)] - return init_fun, apply_fun + return init_fun, apply_fun def ExpandDims(axis=1): - """Layer for expanding dimensions.""" - def init_fun(rng, input_shape): - del rng - input_shape = tuple(input_shape) - if axis < 0: - dims = len(input_shape) - new_axis = dims + 1 - axis - else: - new_axis = axis - return (input_shape[:new_axis] + (1,) + input_shape[new_axis:]), () - def apply_fun(params, inputs, **kwargs): - del params, kwargs - return jnp.expand_dims(inputs, axis) - return init_fun, apply_fun + """Layer for expanding dimensions.""" + + def init_fun(rng, input_shape): + del rng + input_shape = tuple(input_shape) + if axis < 0: + dims = len(input_shape) + new_axis = dims + 1 - axis + else: + new_axis = axis + return (input_shape[:new_axis] + (1,) + input_shape[new_axis:]), () + + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return jnp.expand_dims(inputs, axis) + + return init_fun, apply_fun def AssertNonZeroShape(): - """Layer for checking that no dimension has zero length.""" + """Layer for checking that no dimension has zero length.""" - def init_fun(rng, input_shape): - del rng - return input_shape, () + def init_fun(rng, input_shape): + del rng + return input_shape, () - def apply_fun(params, inputs, **kwargs): - del params, kwargs - assert 0 not in inputs.shape - return inputs + def apply_fun(params, inputs, **kwargs): + del params, kwargs + assert 0 not in inputs.shape + return inputs + + return init_fun, apply_fun - return init_fun, apply_fun # pylint: enable=invalid-name def squeeze_layer(axis=1): - """Layer for squeezing dimension along the axis.""" - def init_fun(rng, input_shape): - del rng - if axis < 0: - raise 
ValueError("squeeze_layer: negative axis is not supported") - return (input_shape[:axis] + input_shape[(axis + 1):]), () + """Layer for squeezing dimension along the axis.""" + + def init_fun(rng, input_shape): + del rng + if axis < 0: + raise ValueError("squeeze_layer: negative axis is not supported") + return (input_shape[:axis] + input_shape[(axis + 1) :]), () - def apply_fun(params, inputs, **kwargs): - del params, kwargs - return inputs.squeeze(axis) + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return inputs.squeeze(axis) - return init_fun, apply_fun + return init_fun, apply_fun def reduce_layer(reduce_fn=jnp.mean, axis=1): - """Apply reduction function to the array along axis.""" - def init_fun(rng, input_shape): - del rng - assert axis >= 0 - assert len(input_shape) == 3 - return input_shape[:axis - 1] + input_shape[axis + 1:], () + """Apply reduction function to the array along axis.""" - def apply_fun(params, inputs, **kwargs): - del params, kwargs - return reduce_fn(inputs, axis=axis) + def init_fun(rng, input_shape): + del rng + assert axis >= 0 + assert len(input_shape) == 3 + return input_shape[: axis - 1] + input_shape[axis + 1 :], () - return init_fun, apply_fun + def apply_fun(params, inputs, **kwargs): + del params, kwargs + return reduce_fn(inputs, axis=axis) + + return init_fun, apply_fun def _create_positional_encoding( # pylint: disable=invalid-name - input_shape, max_len=10000): - """Helper: create positional encoding parameters.""" - d_feature = input_shape[-1] - pe = np.zeros((max_len, d_feature), dtype=np.float32) - position = np.arange(0, max_len)[:, np.newaxis] - div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) - pe[:, 0::2] = np.sin(position * div_term) - pe[:, 1::2] = np.cos(position * div_term) - pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] - return jnp.array(pe) # These are trainable parameters, initialized as above. + input_shape, max_len=10000 +): + """Helper: create positional encoding parameters.""" + d_feature = input_shape[-1] + pe = np.zeros((max_len, d_feature), dtype=np.float32) + position = np.arange(0, max_len)[:, np.newaxis] + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + pe[:, 0::2] = np.sin(position * div_term) + pe[:, 1::2] = np.cos(position * div_term) + pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] + return jnp.array(pe) # These are trainable parameters, initialized as above. 
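# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this patch: it recomputes, with plain
# numpy, the sinusoidal table that `_create_positional_encoding` builds and
# the shape bookkeeping that the `positional_encoding()` layer below relies
# on (concatenating a length-matched slice of the table along the last axis).
# The sizes used here (d_feature=8, max_len=16, batch=2, str_length=10) are
# arbitrary assumptions chosen only to make the shapes visible.
import numpy as np

d_feature, max_len = 8, 16
pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
pe[:, 0::2] = np.sin(position * div_term)   # even feature indices get sin
pe[:, 1::2] = np.cos(position * div_term)   # odd feature indices get cos
pe = pe[np.newaxis, :, :]                   # table shape: [1, max_len, d_feature]

# The layer's apply_fun slices the table to the string length, repeats it over
# the batch, and concatenates it to its 4-D input, so the feature dimension
# grows by d_feature.
batch, extra, str_length = 2, 1, 10
inputs = np.zeros((batch, extra, str_length, d_feature), dtype=np.float32)
enc = np.repeat(pe[None, :, :str_length, :], batch, axis=0)
out = np.concatenate((inputs, enc), axis=-1)
assert out.shape == (batch, extra, str_length, 2 * d_feature)
# ---------------------------------------------------------------------------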
def positional_encoding(): - """Concatenate positional encoding to the last dimension.""" - def init_fun(rng, input_shape): - del rng - input_shape_for_enc = input_shape - params = _create_positional_encoding(input_shape_for_enc) - last_dim = input_shape[-1] + params.shape[-1] - return input_shape[:-1] + (last_dim,), (params,) - - def apply_fun(params, inputs, **kwargs): - del kwargs - assert inputs.ndim == 4 - params = params[0] - symbol_size = inputs.shape[-2] - enc = params[None, :, :symbol_size, :] - enc = jnp.repeat(enc, inputs.shape[0], 0) - return jnp.concatenate((inputs, enc), -1) + """Concatenate positional encoding to the last dimension.""" + + def init_fun(rng, input_shape): + del rng + input_shape_for_enc = input_shape + params = _create_positional_encoding(input_shape_for_enc) + last_dim = input_shape[-1] + params.shape[-1] + return input_shape[:-1] + (last_dim,), (params,) + + def apply_fun(params, inputs, **kwargs): + del kwargs + assert inputs.ndim == 4 + params = params[0] + symbol_size = inputs.shape[-2] + enc = params[None, :, :symbol_size, :] + enc = jnp.repeat(enc, inputs.shape[0], 0) + return jnp.concatenate((inputs, enc), -1) + + return init_fun, apply_fun + + +def cnn( + conv_depth=300, + kernel_size=5, + n_conv_layers=2, + across_batch=False, + add_pos_encoding=False, +): + """Build convolutional neural net.""" + # Input shape: [batch x length x depth] + if across_batch: + extra_dim = 0 + else: + extra_dim = 1 + layers = [ExpandDims(axis=extra_dim)] + if add_pos_encoding: + layers.append(positional_encoding()) + + for _ in range(n_conv_layers): + layers.append( + stax.Conv(conv_depth, (1, kernel_size), padding="same", strides=(1, 1)) + ) + layers.append(stax.Relu) + layers.append(AssertNonZeroShape()) + layers.append(squeeze_layer(axis=extra_dim)) + return stax.serial(*layers) + + +def build_model_stax( + output_size, + n_dense_units=300, + conv_depth=300, + n_conv_layers=2, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + mode="train", +): + """Build a model with convolutional layers followed by dense layers.""" + del mode + layers = [ + cnn( + conv_depth=conv_depth, + n_conv_layers=n_conv_layers, + kernel_size=kernel_size, + across_batch=across_batch, + add_pos_encoding=add_pos_encoding, + ) + ] + for _ in range(n_dense_layers): + layers.append(stax.Dense(n_dense_units)) + layers.append(stax.Relu) - return init_fun, apply_fun + layers.append(stax.Dense(output_size)) + if mean_over_pos: + layers.append(reduce_layer(jnp.mean, axis=1)) + init_random_params, predict = stax.serial(*layers) + return init_random_params, predict -def cnn(conv_depth=300, - kernel_size=5, - n_conv_layers=2, - across_batch=False, - add_pos_encoding=False): - """Build convolutional neural net.""" - # Input shape: [batch x length x depth] - if across_batch: - extra_dim = 0 - else: - extra_dim = 1 - layers = [ExpandDims(axis=extra_dim)] - if add_pos_encoding: - layers.append(positional_encoding()) - - for _ in range(n_conv_layers): - layers.append( - stax.Conv(conv_depth, (1, kernel_size), padding="same", strides=(1, 1))) - layers.append(stax.Relu) - layers.append(AssertNonZeroShape()) - layers.append(squeeze_layer(axis=extra_dim)) - return stax.serial(*layers) - - -def build_model_stax(output_size, - n_dense_units=300, - conv_depth=300, - n_conv_layers=2, - n_dense_layers=0, - kernel_size=5, - across_batch=False, - add_pos_encoding=False, - mean_over_pos=False, - mode="train"): - """Build a model with convolutional layers followed 
by dense layers.""" - del mode - layers = [ - cnn(conv_depth=conv_depth, - n_conv_layers=n_conv_layers, - kernel_size=kernel_size, - across_batch=across_batch, - add_pos_encoding=add_pos_encoding) - ] - for _ in range(n_dense_layers): - layers.append(stax.Dense(n_dense_units)) - layers.append(stax.Relu) - - layers.append(stax.Dense(output_size)) - - if mean_over_pos: - layers.append(reduce_layer(jnp.mean, axis=1)) - init_random_params, predict = stax.serial(*layers) - return init_random_params, predict - - -def sample_log_probs_top_k(log_probs, rng, temperature=1., k=1): - """Sample categorical distribution of log probs using gumbel max trick.""" - noise = jax.random.gumbel(rng, shape=log_probs.shape) - perturbed = (log_probs + noise) / temperature - samples = jnp.argsort(perturbed)[Ellipsis, -k:] - return samples + +def sample_log_probs_top_k(log_probs, rng, temperature=1.0, k=1): + """Sample categorical distribution of log probs using gumbel max trick.""" + noise = jax.random.gumbel(rng, shape=log_probs.shape) + perturbed = (log_probs + noise) / temperature + samples = jnp.argsort(perturbed)[Ellipsis, -k:] + return samples @jax.jit def gather_positions(idx_to_gather, logits): - """Collect logits corresponding to the positions in the string. + """Collect logits corresponding to the positions in the string. - Used for collecting logits for: - 1) positions in the string (depth = 1) - 2) mutation types (depth = n_mut_types) + Used for collecting logits for: + 1) positions in the string (depth = 1) + 2) mutation types (depth = n_mut_types) - Args: - idx_to_gather: [batch_size x num_mutations] Indices of the positions - in the string to gather logits for. - logits: [batch_size x str_length x depth] Logits to index. + Args: + idx_to_gather: [batch_size x num_mutations] Indices of the positions + in the string to gather logits for. + logits: [batch_size x str_length x depth] Logits to index. 
- Returns: - Logits corresponding to the specified positions in the string: - [batch_size, num_mutations, depth] - """ - assert idx_to_gather.shape[0] == logits.shape[0] - assert idx_to_gather.ndim == 2 - assert logits.ndim == 3 + Returns: + Logits corresponding to the specified positions in the string: + [batch_size, num_mutations, depth] + """ + assert idx_to_gather.shape[0] == logits.shape[0] + assert idx_to_gather.ndim == 2 + assert logits.ndim == 3 - batch_size, num_mutations = idx_to_gather.shape - batch_size, str_length, depth = logits.shape + batch_size, num_mutations = idx_to_gather.shape + batch_size, str_length, depth = logits.shape - oh = one_hot(idx_to_gather, str_length) - assert oh.shape == (batch_size, num_mutations, str_length) + oh = one_hot(idx_to_gather, str_length) + assert oh.shape == (batch_size, num_mutations, str_length) - oh = oh[Ellipsis, None] - logits = logits[:, None, :, :] - assert oh.shape == (batch_size, num_mutations, str_length, 1) - assert logits.shape == (batch_size, 1, str_length, depth) + oh = oh[Ellipsis, None] + logits = logits[:, None, :, :] + assert oh.shape == (batch_size, num_mutations, str_length, 1) + assert logits.shape == (batch_size, 1, str_length, depth) - # Perform element-wise multiplication (with broadcasting), - # then sum over str_length dimension - result = jnp.sum(oh * logits, axis=-2) - assert result.shape == (batch_size, num_mutations, depth) - return result + # Perform element-wise multiplication (with broadcasting), + # then sum over str_length dimension + result = jnp.sum(oh * logits, axis=-2) + assert result.shape == (batch_size, num_mutations, depth) + return result class JaxMutationPredictor(object): - """Implements training and predicting from a Jax model. - - Attributes: - output_size: Tuple containing the sizes of components to predict - loss_fn: Loss function. - Format of the loss fn: fn(params, batch, mutations, problem, predictor) - loss_grad_fn: Gradient of the loss function - temperature: temperature parameter for Gumbel-Max sampler. - learning_rate: Learning rate for optimizer. - batch_size: Batch size of input - model_fn: Function which builds the model forward pass. Must have arguments - `vocab_size`, `max_len`, and `mode` and return Jax float arrays. - params: weights of the neural net - make_state: function to make optimizer state given the network parameters - rng: Jax random number generator - """ - - def __init__(self, - vocab_size, - output_size, - loss_fn, - rng, - temperature=1, - learning_rate=0.001, - conv_depth=300, - n_conv_layers=2, - n_dense_units=300, - n_dense_layers=0, - kernel_size=5, - across_batch=False, - add_pos_encoding=False, - mean_over_pos=False, - model_fn=build_model_stax): - self.output_size = output_size - self.temperature = temperature - - # Setup randomness. 
- self.rng = rng - - model_settings = { - "output_size": output_size, - "n_dense_units": n_dense_units, - "n_dense_layers": n_dense_layers, - "conv_depth": conv_depth, - "n_conv_layers": n_conv_layers, - "across_batch": across_batch, - "kernel_size": kernel_size, - "add_pos_encoding": add_pos_encoding, - "mean_over_pos": mean_over_pos, - "mode": "train" - } - - self._model_init, model_train = model_fn(**model_settings) - self._model_train = jax.jit(model_train) - - model_settings["mode"] = "eval" - _, model_predict = model_fn(**model_settings) - self._model_predict = jax.jit(model_predict) - - self.rng, subrng = jrand.split(self.rng) - _, init_params = self._model_init(subrng, (-1, -1, vocab_size)) - self.params = init_params - - # Setup parameters for model and optimizer - self.make_state, self._opt_update_state, self._get_params = adam( - learning_rate) - - self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model) - self.loss_grad_fn = jax.grad(self.loss_fn) - - # Track steps of optimization so far. - self._step_idx = 0 - - def update_step(self, rewards, inputs, actions): - """Performs a single update step on a batch of samples. - - Args: - rewards: Batch [batch] of rewards for perturbed samples. - inputs: Batch [batch x length] of original samples - actions: actions applied on the samples - - Raises: - ValueError: if any inputs are the wrong shape. + """Implements training and predicting from a Jax model. + + Attributes: + output_size: Tuple containing the sizes of components to predict + loss_fn: Loss function. + Format of the loss fn: fn(params, batch, mutations, problem, predictor) + loss_grad_fn: Gradient of the loss function + temperature: temperature parameter for Gumbel-Max sampler. + learning_rate: Learning rate for optimizer. + batch_size: Batch size of input + model_fn: Function which builds the model forward pass. Must have arguments + `vocab_size`, `max_len`, and `mode` and return Jax float arrays. + params: weights of the neural net + make_state: function to make optimizer state given the network parameters + rng: Jax random number generator """ - grad_update = self.loss_grad_fn( - self.params, - rewards=rewards, - inputs=inputs, - actions=actions, - ) + def __init__( + self, + vocab_size, + output_size, + loss_fn, + rng, + temperature=1, + learning_rate=0.001, + conv_depth=300, + n_conv_layers=2, + n_dense_units=300, + n_dense_layers=0, + kernel_size=5, + across_batch=False, + add_pos_encoding=False, + mean_over_pos=False, + model_fn=build_model_stax, + ): + self.output_size = output_size + self.temperature = temperature + + # Setup randomness. 
+ self.rng = rng + + model_settings = { + "output_size": output_size, + "n_dense_units": n_dense_units, + "n_dense_layers": n_dense_layers, + "conv_depth": conv_depth, + "n_conv_layers": n_conv_layers, + "across_batch": across_batch, + "kernel_size": kernel_size, + "add_pos_encoding": add_pos_encoding, + "mean_over_pos": mean_over_pos, + "mode": "train", + } + + self._model_init, model_train = model_fn(**model_settings) + self._model_train = jax.jit(model_train) + + model_settings["mode"] = "eval" + _, model_predict = model_fn(**model_settings) + self._model_predict = jax.jit(model_predict) + + self.rng, subrng = jrand.split(self.rng) + _, init_params = self._model_init(subrng, (-1, -1, vocab_size)) + self.params = init_params + + # Setup parameters for model and optimizer + self.make_state, self._opt_update_state, self._get_params = adam(learning_rate) + + self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model) + self.loss_grad_fn = jax.grad(self.loss_fn) + + # Track steps of optimization so far. + self._step_idx = 0 + + def update_step(self, rewards, inputs, actions): + """Performs a single update step on a batch of samples. + + Args: + rewards: Batch [batch] of rewards for perturbed samples. + inputs: Batch [batch x length] of original samples + actions: actions applied on the samples + + Raises: + ValueError: if any inputs are the wrong shape. + """ + + grad_update = self.loss_grad_fn( + self.params, + rewards=rewards, + inputs=inputs, + actions=actions, + ) - old_params = self.params - state = self.make_state(old_params) - state = self._opt_update_state(self._step_idx, grad_update, state) + old_params = self.params + state = self.make_state(old_params) + state = self._opt_update_state(self._step_idx, grad_update, state) - self.params = self._get_params(state) - del old_params, state + self.params = self._get_params(state) + del old_params, state - self._step_idx += 1 + self._step_idx += 1 - def __call__(self, x, mode="eval"): - """Calls predict function of model. + def __call__(self, x, mode="eval"): + """Calls predict function of model. - Args: - x: Batch of input samples. - mode: Mode for running the network: "train" or "eval" + Args: + x: Batch of input samples. + mode: Mode for running the network: "train" or "eval" - Returns: - A list of tuples (class weights, log likelihood) for each of - output components predicted by the model. - """ - return self.run_model(x, self.params, mode="eval") + Returns: + A list of tuples (class weights, log likelihood) for each of + output components predicted by the model. + """ + return self.run_model(x, self.params, mode="eval") - def run_model(self, x, params, mode="eval"): - """Run the Jax model. + def run_model(self, x, params, mode="eval"): + """Run the Jax model. - This function is used in __call__ to run the model in "eval" mode - and in the loss function to run the model in "train" mode. + This function is used in __call__ to run the model in "eval" mode + and in the loss function to run the model in "train" mode. - Args: - x: Batch of input samples. - params: Network parameters - mode: Mode for running the network: "train" or "eval" + Args: + x: Batch of input samples. + params: Network parameters + mode: Mode for running the network: "train" or "eval" - Returns: - Jax neural network output. - """ - if mode == "train": - model_fn = self._model_train - else: - model_fn = self._model_predict - self.rng, subrng = jax.random.split(self.rng) - return model_fn(params, inputs=x, rng=subrng) + Returns: + Jax neural network output. 
+ """ + if mode == "train": + model_fn = self._model_train + else: + model_fn = self._model_predict + self.rng, subrng = jax.random.split(self.rng) + return model_fn(params, inputs=x, rng=subrng) ######################################### # Loss function def reinforce_loss(rewards, log_likelihood): - """Loss function for Jax model. + """Loss function for Jax model. - Args: - rewards: List of rewards [batch] for the perturbed samples. - log_likelihood: Log-likelihood of perturbations + Args: + rewards: List of rewards [batch] for the perturbed samples. + log_likelihood: Log-likelihood of perturbations - Returns: - Scalar loss. - """ - rewards = jax.lax.stop_gradient(rewards) + Returns: + Scalar loss. + """ + rewards = jax.lax.stop_gradient(rewards) - # In general case, we assume that the loss is not differentiable - # Use REINFORCE - reinforce_estim = rewards * log_likelihood - # Take mean over the number of applied mutations, then across the batch - return -jnp.mean(jnp.mean(reinforce_estim, 1), 0) + # In general case, we assume that the loss is not differentiable + # Use REINFORCE + reinforce_estim = rewards * log_likelihood + # Take mean over the number of applied mutations, then across the batch + return -jnp.mean(jnp.mean(reinforce_estim, 1), 0) def compute_entropy(log_probs): - """Compute entropy of a set of log_probs.""" - return -jnp.mean(jnp.mean(stax.softmax(log_probs) * log_probs, axis=-1)) + """Compute entropy of a set of log_probs.""" + return -jnp.mean(jnp.mean(stax.softmax(log_probs) * log_probs, axis=-1)) def compute_advantage(params, critic_fn, rewards, inputs): - """Compute the advantage: difference between rewards and predicted value. + """Compute the advantage: difference between rewards and predicted value. - Args: - params: parameters for the critic neural net - critic_fn: function to run critic neural net - rewards: rewards for the perturbed samples - inputs: original samples, used as input to the Jax model + Args: + params: parameters for the critic neural net + critic_fn: function to run critic neural net + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model - Returns: - advantage: [batch_size x num_mutations] - """ - assert inputs.ndim == 4 + Returns: + advantage: [batch_size x num_mutations] + """ + assert inputs.ndim == 4 - num_mutations, batch_size, str_length, vocab_size = inputs.shape + num_mutations, batch_size, str_length, vocab_size = inputs.shape - inputs_reshaped = inputs.reshape( - (num_mutations * batch_size, str_length, vocab_size)) + inputs_reshaped = inputs.reshape( + (num_mutations * batch_size, str_length, vocab_size) + ) - predicted_value = critic_fn(inputs_reshaped, params, mode="train") - assert predicted_value.shape == (num_mutations * batch_size, 1) - predicted_value = predicted_value.reshape((num_mutations, batch_size)) + predicted_value = critic_fn(inputs_reshaped, params, mode="train") + assert predicted_value.shape == (num_mutations * batch_size, 1) + predicted_value = predicted_value.reshape((num_mutations, batch_size)) - assert rewards.shape == (batch_size,) - rewards = jnp.repeat(rewards[None, :], num_mutations, 0) - assert rewards.shape == (num_mutations, batch_size) + assert rewards.shape == (batch_size,) + rewards = jnp.repeat(rewards[None, :], num_mutations, 0) + assert rewards.shape == (num_mutations, batch_size) - advantage = rewards - predicted_value - advantage = jnp.transpose(advantage) - assert advantage.shape == (batch_size, num_mutations) - return advantage + advantage = 
rewards - predicted_value + advantage = jnp.transpose(advantage) + assert advantage.shape == (batch_size, num_mutations) + return advantage def value_loss_fn(params, run_model_fn, rewards, inputs, actions=None): - """Compute the loss for the value function. + """Compute the loss for the value function. - Args: - params: parameters for the Jax model - run_model_fn: Jax model to run - rewards: rewards for the perturbed samples - inputs: original samples, used as input to the Jax model - actions: not used + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: not used - Returns: - A scalar loss. - """ - del actions - advantage = compute_advantage(params, run_model_fn, rewards, inputs) - advantage = advantage**2 + Returns: + A scalar loss. + """ + del actions + advantage = compute_advantage(params, run_model_fn, rewards, inputs) + advantage = advantage**2 - return jnp.sqrt(jnp.mean(advantage)) + return jnp.sqrt(jnp.mean(advantage)) def split_mutation_predictor_output(output): - return stax.logsoftmax(output[:, :, -1]), stax.logsoftmax(output[:, :, :-1]) + return stax.logsoftmax(output[:, :, -1]), stax.logsoftmax(output[:, :, :-1]) -def run_model_and_compute_reinforce_loss(params, - run_model_fn, - rewards, - inputs, - actions, - n_mutations, - entropy_weight=0.1): - """Run Jax model and compute REINFORCE loss. +def run_model_and_compute_reinforce_loss( + params, run_model_fn, rewards, inputs, actions, n_mutations, entropy_weight=0.1 +): + """Run Jax model and compute REINFORCE loss. - Jax can compute the gradients of the model only if the model is called inside - the loss function. Here we call the Jax model, re-compute the log-likelihoods, - take log-likelihoods of the mutations and positions sampled before in - _propose function of the solver, and compute the loss. + Jax can compute the gradients of the model only if the model is called inside + the loss function. Here we call the Jax model, re-compute the log-likelihoods, + take log-likelihoods of the mutations and positions sampled before in + _propose function of the solver, and compute the loss. - Args: - params: parameters for the Jax model - run_model_fn: Jax model to run - rewards: rewards for the perturbed samples - inputs: original samples, used as input to the Jax model - actions: Tuple (mut_types [Batch], positions [Batch]) of mutation types - and positions sampled during the _propose() step of evolution solver. - n_mutations: Number of mutations. Used for one-hot encoding of mutations - entropy_weight: Weight on the entropy term added to the loss. + Args: + params: parameters for the Jax model + run_model_fn: Jax model to run + rewards: rewards for the perturbed samples + inputs: original samples, used as input to the Jax model + actions: Tuple (mut_types [Batch], positions [Batch]) of mutation types + and positions sampled during the _propose() step of evolution solver. + n_mutations: Number of mutations. Used for one-hot encoding of mutations + entropy_weight: Weight on the entropy term added to the loss. - Returns: - A scalar loss. + Returns: + A scalar loss. 
- """ - mut_types, positions = actions - mut_types_one_hot = one_hot(mut_types, n_mutations) + """ + mut_types, positions = actions + mut_types_one_hot = one_hot(mut_types, n_mutations) - batch_size, str_length, _ = inputs.shape - assert mut_types.shape[0] == inputs.shape[0] - batch_size, num_mutations = mut_types.shape - assert mut_types.shape == positions.shape - assert mut_types.shape == rewards.shape + batch_size, str_length, _ = inputs.shape + assert mut_types.shape[0] == inputs.shape[0] + batch_size, num_mutations = mut_types.shape + assert mut_types.shape == positions.shape + assert mut_types.shape == rewards.shape - output = run_model_fn(inputs, params, mode="train") - pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) - assert pos_log_probs.shape == (batch_size, str_length) - pos_log_probs = jnp.expand_dims(pos_log_probs, -1) + output = run_model_fn(inputs, params, mode="train") + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + assert pos_log_probs.shape == (batch_size, str_length) + pos_log_probs = jnp.expand_dims(pos_log_probs, -1) - pos_log_likelihoods = gather_positions(positions, pos_log_probs) - assert pos_log_likelihoods.shape == (batch_size, num_mutations, 1) + pos_log_likelihoods = gather_positions(positions, pos_log_probs) + assert pos_log_likelihoods.shape == (batch_size, num_mutations, 1) - # Sum over number of positions - pos_log_likelihoods = jnp.sum(pos_log_likelihoods, -1) + # Sum over number of positions + pos_log_likelihoods = jnp.sum(pos_log_likelihoods, -1) - # all_mut_log_probs shape: [batch_size, str_length, n_mut_types] - assert all_mut_log_probs.shape[:2] == (batch_size, str_length) + # all_mut_log_probs shape: [batch_size, str_length, n_mut_types] + assert all_mut_log_probs.shape[:2] == (batch_size, str_length) - # Get mutation logits corresponding to the chosen positions - mutation_logprobs = gather_positions(positions, all_mut_log_probs) + # Get mutation logits corresponding to the chosen positions + mutation_logprobs = gather_positions(positions, all_mut_log_probs) - # Get log probs corresponding to the selected mutations - mut_log_likelihoods_oh = mutation_logprobs * mut_types_one_hot + # Get log probs corresponding to the selected mutations + mut_log_likelihoods_oh = mutation_logprobs * mut_types_one_hot - # Sum over mutation types - mut_log_likelihoods = jnp.sum(mut_log_likelihoods_oh, -1) - assert mut_log_likelihoods.shape == (batch_size, num_mutations) + # Sum over mutation types + mut_log_likelihoods = jnp.sum(mut_log_likelihoods_oh, -1) + assert mut_log_likelihoods.shape == (batch_size, num_mutations) - joint_log_likelihood = mut_log_likelihoods + pos_log_likelihoods - assert joint_log_likelihood.shape == (batch_size, num_mutations) + joint_log_likelihood = mut_log_likelihoods + pos_log_likelihoods + assert joint_log_likelihood.shape == (batch_size, num_mutations) - loss = reinforce_loss(rewards, joint_log_likelihood) - loss -= entropy_weight * compute_entropy(mutation_logprobs) - return loss + loss = reinforce_loss(rewards, joint_log_likelihood) + loss -= entropy_weight * compute_entropy(mutation_logprobs) + return loss ############################################ # MutationPredictorSolver def initialize_uniformly(domain, batch_size, random_state): - return domain.sample_uniformly(batch_size, seed=random_state) + return domain.sample_uniformly(batch_size, seed=random_state) -#@gin.configurable +# @gin.configurable class MutationPredictorSolver(base_solver.BaseSolver): - """Choose the mutation 
operator conditioned on the sample. + """Choose the mutation operator conditioned on the sample. - Sample from categorical distribution over available mutation operators - using Gumbel-Max trick - """ + Sample from categorical distribution over available mutation operators + using Gumbel-Max trick + """ - def __init__(self, - domain, - model_fn=build_model_stax, - random_state=0, - **kwargs): - """Constructs solver. + def __init__(self, domain, model_fn=build_model_stax, random_state=0, **kwargs): + """Constructs solver. + + Args: + domain: discrete domain + model_fn: Function which builds the forward pass of predictor model. + random_state: Random state to initialize jax & np RNGs. + **kwargs: kwargs passed to config. + """ + super(MutationPredictorSolver, self).__init__( + domain=domain, random_state=random_state, **kwargs + ) + self.rng = jrand.PRNGKey(random_state) + self.rng, rng = jax.random.split(self.rng) + + if self.domain.length < self.cfg.num_mutations: + # logging.warning("Number of mutations to perform per string exceeds string" + # " length. The number of mutation is set to be equal to " + # "the string length.") + self.cfg.num_mutations = self.domain.length + + # Right now the mutations are defined as "Set position X to character C". + # It allows to vectorize applying mutations to the string and speeds up + # the solver. + # If using other types of mutations, set self.use_assignment_mut=False. + self.mutations = [] + for val in range(self.domain.vocab_size): + self.mutations.append(functools.partial(set_pos, val=val)) + self.use_assignment_mut = True + + mut_loss_fn = functools.partial( + run_model_and_compute_reinforce_loss, n_mutations=len(self.mutations) + ) - Args: - domain: discrete domain - model_fn: Function which builds the forward pass of predictor model. - random_state: Random state to initialize jax & np RNGs. - **kwargs: kwargs passed to config. - """ - super(MutationPredictorSolver, self).__init__( - domain=domain, random_state=random_state, **kwargs) - self.rng = jrand.PRNGKey(random_state) - self.rng, rng = jax.random.split(self.rng) - - if self.domain.length < self.cfg.num_mutations: - # logging.warning("Number of mutations to perform per string exceeds string" - # " length. The number of mutation is set to be equal to " - # "the string length.") - self.cfg.num_mutations = self.domain.length - - # Right now the mutations are defined as "Set position X to character C". - # It allows to vectorize applying mutations to the string and speeds up - # the solver. - # If using other types of mutations, set self.use_assignment_mut=False. 
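+        # mut_loss_fn is the REINFORCE objective defined above
+        # (run_model_and_compute_reinforce_loss), partially applied with the
+        # number of mutation types; it is used as the actor's training loss.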
- self.mutations = [] - for val in range(self.domain.vocab_size): - self.mutations.append(functools.partial(set_pos, val=val)) - self.use_assignment_mut = True - - mut_loss_fn = functools.partial(run_model_and_compute_reinforce_loss, - n_mutations=len(self.mutations)) - - # Predictor that takes the input string - # Outputs the weights over the 1) mutations types 2) position in string - if self.cfg.pretrained_model is None: - self._mut_predictor = self.cfg.predictor( - vocab_size=self.domain.vocab_size, - output_size=len(self.mutations) + 1, - loss_fn=mut_loss_fn, - rng=rng, - model_fn=build_model_stax, - conv_depth=self.cfg.actor_conv_depth, - n_conv_layers=self.cfg.actor_n_conv_layers, - n_dense_units=self.cfg.actor_n_dense_units, - n_dense_layers=self.cfg.actor_n_dense_layers, - across_batch=self.cfg.actor_across_batch, - add_pos_encoding=self.cfg.actor_add_pos_encoding, - kernel_size=self.cfg.actor_kernel_size, - learning_rate=self.cfg.actor_learning_rate, - ) - - if self.cfg.use_actor_critic: - self._value_predictor = self.cfg.predictor( - vocab_size=self.domain.vocab_size, - output_size=1, - rng=rng, - loss_fn=value_loss_fn, - model_fn=build_model_stax, - mean_over_pos=True, - conv_depth=self.cfg.critic_conv_depth, - n_conv_layers=self.cfg.critic_n_conv_layers, - n_dense_units=self.cfg.critic_n_dense_units, - n_dense_layers=self.cfg.critic_n_dense_layers, - across_batch=self.cfg.critic_across_batch, - add_pos_encoding=self.cfg.critic_add_pos_encoding, - kernel_size=self.cfg.critic_kernel_size, - learning_rate=self.cfg.critic_learning_rate, + # Predictor that takes the input string + # Outputs the weights over the 1) mutations types 2) position in string + if self.cfg.pretrained_model is None: + self._mut_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=len(self.mutations) + 1, + loss_fn=mut_loss_fn, + rng=rng, + model_fn=build_model_stax, + conv_depth=self.cfg.actor_conv_depth, + n_conv_layers=self.cfg.actor_n_conv_layers, + n_dense_units=self.cfg.actor_n_dense_units, + n_dense_layers=self.cfg.actor_n_dense_layers, + across_batch=self.cfg.actor_across_batch, + add_pos_encoding=self.cfg.actor_add_pos_encoding, + kernel_size=self.cfg.actor_kernel_size, + learning_rate=self.cfg.actor_learning_rate, + ) + + if self.cfg.use_actor_critic: + self._value_predictor = self.cfg.predictor( + vocab_size=self.domain.vocab_size, + output_size=1, + rng=rng, + loss_fn=value_loss_fn, + model_fn=build_model_stax, + mean_over_pos=True, + conv_depth=self.cfg.critic_conv_depth, + n_conv_layers=self.cfg.critic_n_conv_layers, + n_dense_units=self.cfg.critic_n_dense_units, + n_dense_layers=self.cfg.critic_n_dense_layers, + across_batch=self.cfg.critic_across_batch, + add_pos_encoding=self.cfg.critic_add_pos_encoding, + kernel_size=self.cfg.critic_kernel_size, + learning_rate=self.cfg.critic_learning_rate, + ) + + else: + self._value_predictor = None + else: + self._mut_predictor, self._value_predictor = self.cfg.pretrained_model + + self._data_for_grad_update = [] + self._initialized = False + + def _config(self): + cfg = super(MutationPredictorSolver, self)._config() + + cfg.update( + dict( + predictor=JaxMutationPredictor, + temperature=1.0, + initialize_dataset_fn=initialize_uniformly, + elite_set_size=10, + use_random_network=False, + exploit_with_best=True, + use_selection_of_best=False, + pretrained_model=None, + # Indicator to BO to pass in previous weights. + # As implemented in cl/318101597. 
+ warmstart=True, + use_actor_critic=False, + num_mutations=5, + # Hyperparameters for actor + actor_learning_rate=0.001, + actor_conv_depth=300, + actor_n_conv_layers=1, + actor_n_dense_units=100, + actor_n_dense_layers=0, + actor_kernel_size=5, + actor_across_batch=False, + actor_add_pos_encoding=True, + # Hyperparameters for critic + critic_learning_rate=0.001, + critic_conv_depth=300, + critic_n_conv_layers=1, + critic_n_dense_units=300, + critic_n_dense_layers=0, + critic_kernel_size=5, + critic_across_batch=False, + critic_add_pos_encoding=True, + ) + ) + return cfg + + def _get_unique(self, samples): + unique_population = data.Population() + unique_structures = set() + for sample in samples: + hashed_structure = utils.hash_structure(sample.structure) + if hashed_structure in unique_structures: + continue + unique_structures.add(hashed_structure) + unique_population.add_samples([sample]) + return unique_population + + def _get_best_samples_from_last_batch( + self, population, n=1, discard_duplicates=True + ): + best_samples = population.get_last_batch().best_n( + n, discard_duplicates=discard_duplicates ) + return best_samples.structures, best_samples.rewards + + def _select(self, population): + if self.cfg.use_selection_of_best: + # Choose best samples from the previous batch + structures, rewards = self._get_best_samples_from_last_batch( + population, self.cfg.elite_set_size + ) + + # Choose the samples to perturb with replacement + idx = np.random.choice( + len(structures), self.batch_size, replace=True + ) # pytype: disable=attribute-error # trace-all-classes + selected_structures = np.stack([structures[i] for i in idx]) + selected_rewards = np.stack([rewards[i] for i in idx]) + return selected_structures, selected_rewards + else: + # Just return the samples from the previous batch -- no selection + last_batch = population.get_last_batch() + structures = np.array([x.structure for x in last_batch]) + rewards = np.array([x.reward for x in last_batch]) + + if ( + len(last_batch) > self.batch_size + ): # pytype: disable=attribute-error # trace-all-classes + # Subsample the data + idx = np.random.choice( + len(last_batch), self.batch_size, replace=False + ) # pytype: disable=attribute-error # trace-all-classes + structures = np.stack([structures[i] for i in idx]) + rewards = np.stack([rewards[i] for i in idx]) + return structures, rewards + + def propose( + self, num_samples, population=None, pending_samples=None + ): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + # Initialize population randomly. 
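+        # Typical usage (illustrative sketch; names such as `rewards_for_x`
+        # are placeholders, not part of the original code):
+        #   x = solver.propose(num_samples)                   # random batch
+        #   pop = data.Population(
+        #       data.samples_from_arrays(structures=x, rewards=rewards_for_x))
+        #   x = solver.propose(num_samples, population=pop)   # learned step
+        # The first call only draws from cfg.initialize_dataset_fn; later
+        # calls select parents from the population's last batch, perturb them
+        # with the predicted mutations and, once rewards for a previous batch
+        # are available and use_random_network is False, update the policy.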
+ if self._initialized and population: + if num_samples != self.batch_size: + raise ValueError("Must maintain constant batch size between runs.") + counter = population.max_batch_index + if counter > 0: + if not self.cfg.use_random_network: + self._update_params(population) + else: + self.batch_size = num_samples + self._initialized = True + return self.cfg.initialize_dataset_fn( + self.domain, num_samples, random_state=self._random_state + ) + + # Choose best samples so far -- [elite_set_size] + samples_to_perturb, parent_rewards = self._select(population) + + perturbed, actions, mut_predictor_input = self._perturb(samples_to_perturb) - else: - self._value_predictor = None - else: - self._mut_predictor, self._value_predictor = self.cfg.pretrained_model - - self._data_for_grad_update = [] - self._initialized = False - - def _config(self): - cfg = super(MutationPredictorSolver, self)._config() - - cfg.update( - dict( - predictor=JaxMutationPredictor, - temperature=1., - initialize_dataset_fn=initialize_uniformly, - elite_set_size=10, - use_random_network=False, - exploit_with_best=True, - use_selection_of_best=False, - pretrained_model=None, - - # Indicator to BO to pass in previous weights. - # As implemented in cl/318101597. - warmstart=True, - - use_actor_critic=False, - num_mutations=5, - - # Hyperparameters for actor - actor_learning_rate=0.001, - actor_conv_depth=300, - actor_n_conv_layers=1, - actor_n_dense_units=100, - actor_n_dense_layers=0, - actor_kernel_size=5, - actor_across_batch=False, - actor_add_pos_encoding=True, - - # Hyperparameters for critic - critic_learning_rate=0.001, - critic_conv_depth=300, - critic_n_conv_layers=1, - critic_n_dense_units=300, - critic_n_dense_layers=0, - critic_kernel_size=5, - critic_across_batch=False, - critic_add_pos_encoding=True, - )) - return cfg - - def _get_unique(self, samples): - unique_population = data.Population() - unique_structures = set() - for sample in samples: - hashed_structure = utils.hash_structure(sample.structure) - if hashed_structure in unique_structures: - continue - unique_structures.add(hashed_structure) - unique_population.add_samples([sample]) - return unique_population - - def _get_best_samples_from_last_batch(self, - population, - n=1, - discard_duplicates=True): - best_samples = population.get_last_batch().best_n( - n, discard_duplicates=discard_duplicates) - return best_samples.structures, best_samples.rewards - - def _select(self, population): - if self.cfg.use_selection_of_best: - # Choose best samples from the previous batch - structures, rewards = self._get_best_samples_from_last_batch( - population, self.cfg.elite_set_size) - - # Choose the samples to perturb with replacement - idx = np.random.choice(len(structures), self.batch_size, replace=True) # pytype: disable=attribute-error # trace-all-classes - selected_structures = np.stack([structures[i] for i in idx]) - selected_rewards = np.stack([rewards[i] for i in idx]) - return selected_structures, selected_rewards - else: - # Just return the samples from the previous batch -- no selection - last_batch = population.get_last_batch() - structures = np.array([x.structure for x in last_batch]) - rewards = np.array([x.reward for x in last_batch]) - - if len(last_batch) > self.batch_size: # pytype: disable=attribute-error # trace-all-classes - # Subsample the data - idx = np.random.choice(len(last_batch), self.batch_size, replace=False) # pytype: disable=attribute-error # trace-all-classes - structures = np.stack([structures[i] for i in idx]) - rewards = 
np.stack([rewards[i] for i in idx]) - return structures, rewards - - def propose(self, num_samples, population=None, pending_samples=None): # pytype: disable=signature-mismatch # overriding-parameter-count-checks - # Initialize population randomly. - if self._initialized and population: - if num_samples != self.batch_size: - raise ValueError("Must maintain constant batch size between runs.") - counter = population.max_batch_index - if counter > 0: if not self.cfg.use_random_network: - self._update_params(population) - else: - self.batch_size = num_samples - self._initialized = True - return self.cfg.initialize_dataset_fn( - self.domain, num_samples, random_state=self._random_state) - - # Choose best samples so far -- [elite_set_size] - samples_to_perturb, parent_rewards = self._select(population) - - perturbed, actions, mut_predictor_input = self._perturb(samples_to_perturb) - - if not self.cfg.use_random_network: - self._data_for_grad_update.append({ - "batch_index": population.current_batch_index + 1, - "mut_predictor_input": mut_predictor_input, - "actions": actions, - "parent_rewards": parent_rewards, - }) - - return np.asarray(perturbed) - - def _perturb(self, parents, mode="train"): - length = parents.shape[1] - assert length == self.domain.length - - parents_one_hot = one_hot(parents, self.domain.vocab_size) - - output = self._mut_predictor(parents_one_hot) - pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) - - self.rng, subrng = jax.random.split(self.rng) - positions = sample_log_probs_top_k( - pos_log_probs, - subrng, - k=self.cfg.num_mutations, - temperature=self.cfg.temperature) + self._data_for_grad_update.append( + { + "batch_index": population.current_batch_index + 1, + "mut_predictor_input": mut_predictor_input, + "actions": actions, + "parent_rewards": parent_rewards, + } + ) + + return np.asarray(perturbed) + + def _perturb(self, parents, mode="train"): + length = parents.shape[1] + assert length == self.domain.length + + parents_one_hot = one_hot(parents, self.domain.vocab_size) + + output = self._mut_predictor(parents_one_hot) + pos_log_probs, all_mut_log_probs = split_mutation_predictor_output(output) + + self.rng, subrng = jax.random.split(self.rng) + positions = sample_log_probs_top_k( + pos_log_probs, + subrng, + k=self.cfg.num_mutations, + temperature=self.cfg.temperature, + ) - pos_masks = one_hot(positions, length) + pos_masks = one_hot(positions, length) - mutation_logprobs = gather_positions(positions, all_mut_log_probs) - assert mutation_logprobs.shape == (output.shape[0], self.cfg.num_mutations, - output.shape[-1] - 1) - - self.rng, subrng = jax.random.split(self.rng) - mutation_types = gumbel_max_sampler( - mutation_logprobs, self.cfg.temperature, subrng) - - states = apply_mutations(parents, mutation_types, pos_masks, self.mutations, - use_assignment_mutations=self.use_assignment_mut) - # states shape: [num_mutations+1, batch, str_length] - # states[0] are original samples with no mutations - # states[-1] are strings with all mutations applied to them - states_oh = one_hot(states, self.domain.vocab_size) - # states_oh shape: [n_mutations+1, batch, str_length, vocab_size] - perturbed = states[-1] - - return perturbed, (mutation_types, positions), states_oh - - def _update_params(self, population): - if not self._data_for_grad_update: - return - - dat = self._data_for_grad_update.pop() - assert dat["batch_index"] == population.current_batch_index - - child_rewards = jnp.array(population.get_last_batch().rewards) - parent_rewards = 
dat["parent_rewards"] - - all_states = dat["mut_predictor_input"] - # all_states shape: [num_mutations, batch_size, str_length, vocab_size] - - # TODO(rubanova): rescale the rewards - terminal_rewards = child_rewards - - if self.cfg.use_actor_critic: - # Update the value function - # Compute the difference between predicted value of intermediate states - # and the final reward. - self._value_predictor.update_step( - rewards=terminal_rewards, - inputs=all_states[:-1], - actions=None, - ) - advantage = compute_advantage(self._value_predictor.params, - self._value_predictor.run_model, - terminal_rewards, all_states[:-1]) - else: - advantage = child_rewards - parent_rewards - advantage = jnp.repeat(advantage[:, None], self.cfg.num_mutations, 1) + mutation_logprobs = gather_positions(positions, all_mut_log_probs) + assert mutation_logprobs.shape == ( + output.shape[0], + self.cfg.num_mutations, + output.shape[-1] - 1, + ) - advantage = jax.lax.stop_gradient(advantage) + self.rng, subrng = jax.random.split(self.rng) + mutation_types = gumbel_max_sampler( + mutation_logprobs, self.cfg.temperature, subrng + ) - # Perform policy update. - # Compute policy on the original samples, like in _perturb function. - self._mut_predictor.update_step( - rewards=advantage, - inputs=all_states[0], - actions=dat["actions"]) + states = apply_mutations( + parents, + mutation_types, + pos_masks, + self.mutations, + use_assignment_mutations=self.use_assignment_mut, + ) + # states shape: [num_mutations+1, batch, str_length] + # states[0] are original samples with no mutations + # states[-1] are strings with all mutations applied to them + states_oh = one_hot(states, self.domain.vocab_size) + # states_oh shape: [n_mutations+1, batch, str_length, vocab_size] + perturbed = states[-1] + + return perturbed, (mutation_types, positions), states_oh + + def _update_params(self, population): + if not self._data_for_grad_update: + return + + dat = self._data_for_grad_update.pop() + assert dat["batch_index"] == population.current_batch_index + + child_rewards = jnp.array(population.get_last_batch().rewards) + parent_rewards = dat["parent_rewards"] + + all_states = dat["mut_predictor_input"] + # all_states shape: [num_mutations, batch_size, str_length, vocab_size] + + # TODO(rubanova): rescale the rewards + terminal_rewards = child_rewards + + if self.cfg.use_actor_critic: + # Update the value function + # Compute the difference between predicted value of intermediate states + # and the final reward. + self._value_predictor.update_step( + rewards=terminal_rewards, + inputs=all_states[:-1], + actions=None, + ) + advantage = compute_advantage( + self._value_predictor.params, + self._value_predictor.run_model, + terminal_rewards, + all_states[:-1], + ) + else: + advantage = child_rewards - parent_rewards + advantage = jnp.repeat(advantage[:, None], self.cfg.num_mutations, 1) + + advantage = jax.lax.stop_gradient(advantage) + + # Perform policy update. + # Compute policy on the original samples, like in _perturb function. 
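+        # update_step performs a gradient update on the actor's REINFORCE
+        # loss; `advantage` stands in for the rewards, weighting the
+        # log-likelihood of the sampled positions and mutation types.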
+ self._mut_predictor.update_step( + rewards=advantage, inputs=all_states[0], actions=dat["actions"] + ) - del all_states, advantage + del all_states, advantage - @property - def trained_model(self): - return (self._mut_predictor, self._value_predictor) \ No newline at end of file + @property + def trained_model(self): + return (self._mut_predictor, self._value_predictor) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py index d271de1..30c9599 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/domains.py @@ -26,247 +26,248 @@ from poli_baselines.solvers.bayesian_optimization.amortized import utils -BOS_TOKEN = '<' # Beginning of sequence token. -EOS_TOKEN = '>' # End of sequence token. -PAD_TOKEN = '_' # End of sequence token. -MASK_TOKEN = '*' # End of sequence token. -SEP_TOKEN = '|' # A special token for separating tokens for serialization. +BOS_TOKEN = "<" # Beginning of sequence token. +EOS_TOKEN = ">" # End of sequence token. +PAD_TOKEN = "_" # End of sequence token. +MASK_TOKEN = "*" # End of sequence token. +SEP_TOKEN = "|" # A special token for separating tokens for serialization. @gin.configurable class Vocabulary(object): - """Basic vocabulary used to represent output tokens for domains.""" - - def __init__(self, - tokens, - include_bos=False, - include_eos=False, - include_pad=False, - include_mask=False, - bos_token=BOS_TOKEN, - eos_token=EOS_TOKEN, - pad_token=PAD_TOKEN, - mask_token=MASK_TOKEN): - """A token vocabulary. - - Args: - tokens: An list of tokens to put in the vocab. If an int, will be - interpreted as the number of tokens and '0', ..., 'tokens-1' will be - used as tokens. - include_bos: Whether to append `bos_token` to `tokens` that marks the - beginning of a sequence. - include_eos: Whether to append `eos_token` to `tokens` that marks the - end of a sequence. - include_pad: Whether to append `pad_token` to `tokens` to marks past end - of sequence. - include_mask: Whether to append `mask_token` to `tokens` to mark masked - positions. - bos_token: A special token than marks the beginning of sequence. - Ignored if `include_bos == False`. - eos_token: A special token than marks the end of sequence. - Ignored if `include_eos == False`. - pad_token: A special token than marks past the end of sequence. - Ignored if `include_pad == False`. - mask_token: A special token than marks MASKED positions for e.g. BERT. - Ignored if `include_mask == False`. 
- """ - if not isinstance(tokens, Iterable): - tokens = range(tokens) - tokens = [str(token) for token in tokens] - if include_bos: - tokens.append(bos_token) - if include_eos: - tokens.append(eos_token) - if include_pad: - tokens.append(pad_token) - if include_mask: - tokens.append(mask_token) - if len(set(tokens)) != len(tokens): - raise ValueError('tokens not unique!') - special_tokens = sorted(set(tokens) & set([SEP_TOKEN])) - if special_tokens: - raise ValueError( - f'tokens contains reserved special tokens: {special_tokens}!') - - self._tokens = tokens - self._token_ids = list(range(len(self._tokens))) - self._id_to_token = collections.OrderedDict( - zip(self._token_ids, self._tokens)) - self._token_to_id = collections.OrderedDict( - zip(self._tokens, self._token_ids)) - self._bos_token = bos_token if include_bos else None - self._eos_token = eos_token if include_eos else None - self._mask_token = mask_token if include_mask else None - self._pad_token = pad_token if include_pad else None - - def __len__(self): - return len(self._tokens) - - @property - def tokens(self): - """Return the tokens of the vocabulary.""" - return list(self._tokens) - - @property - def token_ids(self): - """Return the tokens ids of the vocabulary.""" - return list(self._token_ids) - - @property - def bos(self): - """Returns the index of the BOS token or None if unspecified.""" - return (None if self._bos_token is None else - self._token_to_id[self._bos_token]) - - @property - def eos(self): - """Returns the index of the EOS token or None if unspecified.""" - return (None if self._eos_token is None else - self._token_to_id[self._eos_token]) - - @property - def mask(self): - """Returns the index of the MASK token or None if unspecified.""" - return (None if self._mask_token is None else - self._token_to_id[self._mask_token]) - - @property - def pad(self): - """Returns the index of the PAD token or None if unspecified.""" - return (None - if self._pad_token is None else self._token_to_id[self._pad_token]) - - def is_valid(self, value): - """Tests if a value is a valid token id and returns a bool.""" - return value in self._token_ids - - def are_valid(self, values): - """Tests if values are valid token ids and returns an array of bools.""" - return np.array([self.is_valid(value) for value in values]) - - def encode(self, tokens): - """Maps an iterable of string tokens to a list of integer token ids.""" - if six.PY3 and isinstance(tokens, bytes): - # Always use Unicode in Python 3. - tokens = tokens.decode('utf-8') - return [self._token_to_id[token] for token in tokens] - - def decode(self, values, stop_at_eos=False, as_str=True): - """Maps an iterable of integer token ids to string tokens. - - Args: - values: An iterable of token ids. - stop_at_eos: Whether to ignore all values after the first EOS token id. - as_str: Whether to return a list of tokens or a concatenated string. - - Returns: - A string of tokens or a list of tokens if `as_str == False`. - """ - if stop_at_eos and self.eos is None: - raise ValueError('EOS unspecified!') - tokens = [] - for value in values: - value = int(value) # Requires if value is a scalar tensor. 
- if stop_at_eos and value == self.eos: - break - tokens.append(self._id_to_token[value]) - return ''.join(tokens) if as_str else tokens + """Basic vocabulary used to represent output tokens for domains.""" + + def __init__( + self, + tokens, + include_bos=False, + include_eos=False, + include_pad=False, + include_mask=False, + bos_token=BOS_TOKEN, + eos_token=EOS_TOKEN, + pad_token=PAD_TOKEN, + mask_token=MASK_TOKEN, + ): + """A token vocabulary. + + Args: + tokens: An list of tokens to put in the vocab. If an int, will be + interpreted as the number of tokens and '0', ..., 'tokens-1' will be + used as tokens. + include_bos: Whether to append `bos_token` to `tokens` that marks the + beginning of a sequence. + include_eos: Whether to append `eos_token` to `tokens` that marks the + end of a sequence. + include_pad: Whether to append `pad_token` to `tokens` to marks past end + of sequence. + include_mask: Whether to append `mask_token` to `tokens` to mark masked + positions. + bos_token: A special token than marks the beginning of sequence. + Ignored if `include_bos == False`. + eos_token: A special token than marks the end of sequence. + Ignored if `include_eos == False`. + pad_token: A special token than marks past the end of sequence. + Ignored if `include_pad == False`. + mask_token: A special token than marks MASKED positions for e.g. BERT. + Ignored if `include_mask == False`. + """ + if not isinstance(tokens, Iterable): + tokens = range(tokens) + tokens = [str(token) for token in tokens] + if include_bos: + tokens.append(bos_token) + if include_eos: + tokens.append(eos_token) + if include_pad: + tokens.append(pad_token) + if include_mask: + tokens.append(mask_token) + if len(set(tokens)) != len(tokens): + raise ValueError("tokens not unique!") + special_tokens = sorted(set(tokens) & set([SEP_TOKEN])) + if special_tokens: + raise ValueError( + f"tokens contains reserved special tokens: {special_tokens}!" 
+ ) + + self._tokens = tokens + self._token_ids = list(range(len(self._tokens))) + self._id_to_token = collections.OrderedDict(zip(self._token_ids, self._tokens)) + self._token_to_id = collections.OrderedDict(zip(self._tokens, self._token_ids)) + self._bos_token = bos_token if include_bos else None + self._eos_token = eos_token if include_eos else None + self._mask_token = mask_token if include_mask else None + self._pad_token = pad_token if include_pad else None + + def __len__(self): + return len(self._tokens) + + @property + def tokens(self): + """Return the tokens of the vocabulary.""" + return list(self._tokens) + + @property + def token_ids(self): + """Return the tokens ids of the vocabulary.""" + return list(self._token_ids) + + @property + def bos(self): + """Returns the index of the BOS token or None if unspecified.""" + return None if self._bos_token is None else self._token_to_id[self._bos_token] + + @property + def eos(self): + """Returns the index of the EOS token or None if unspecified.""" + return None if self._eos_token is None else self._token_to_id[self._eos_token] + + @property + def mask(self): + """Returns the index of the MASK token or None if unspecified.""" + return None if self._mask_token is None else self._token_to_id[self._mask_token] + + @property + def pad(self): + """Returns the index of the PAD token or None if unspecified.""" + return None if self._pad_token is None else self._token_to_id[self._pad_token] + + def is_valid(self, value): + """Tests if a value is a valid token id and returns a bool.""" + return value in self._token_ids + + def are_valid(self, values): + """Tests if values are valid token ids and returns an array of bools.""" + return np.array([self.is_valid(value) for value in values]) + + def encode(self, tokens): + """Maps an iterable of string tokens to a list of integer token ids.""" + if six.PY3 and isinstance(tokens, bytes): + # Always use Unicode in Python 3. + tokens = tokens.decode("utf-8") + return [self._token_to_id[token] for token in tokens] + + def decode(self, values, stop_at_eos=False, as_str=True): + """Maps an iterable of integer token ids to string tokens. + + Args: + values: An iterable of token ids. + stop_at_eos: Whether to ignore all values after the first EOS token id. + as_str: Whether to return a list of tokens or a concatenated string. + + Returns: + A string of tokens or a list of tokens if `as_str == False`. + """ + if stop_at_eos and self.eos is None: + raise ValueError("EOS unspecified!") + tokens = [] + for value in values: + value = int(value) # Requires if value is a scalar tensor. 
+ if stop_at_eos and value == self.eos: + break + tokens.append(self._id_to_token[value]) + return "".join(tokens) if as_str else tokens @six.add_metaclass(abc.ABCMeta) class Domain(object): - """Base class of problem domains, which specifies the set of valid objects.""" + """Base class of problem domains, which specifies the set of valid objects.""" - @property - def mask_fn(self): - """Returns a masking function or None.""" + @property + def mask_fn(self): + """Returns a masking function or None.""" - @abc.abstractmethod - def is_valid(self, sample): - """Tests if the given sample is valid for this domain.""" + @abc.abstractmethod + def is_valid(self, sample): + """Tests if the given sample is valid for this domain.""" - def are_valid(self, samples): - """Tests if the given samples are valid for this domain.""" - return np.array([self.is_valid(sample) for sample in samples]) + def are_valid(self, samples): + """Tests if the given samples are valid for this domain.""" + return np.array([self.is_valid(sample) for sample in samples]) class DiscreteDomain(Domain): - """Base class for discrete domains: sequences of categorical variables.""" + """Base class for discrete domains: sequences of categorical variables.""" - def __init__(self, vocab): - self._vocab = vocab + def __init__(self, vocab): + self._vocab = vocab - @property - def vocab_size(self): - return len(self.vocab) + @property + def vocab_size(self): + return len(self.vocab) - @property - def vocab(self): - return self._vocab # pytype: disable=attribute-error # trace-all-classes + @property + def vocab(self): + return self._vocab # pytype: disable=attribute-error # trace-all-classes - def encode(self, samples, **kwargs): - """Maps a list of string tokens to a list of lists of integer token ids.""" - return [self.vocab.encode(sample, **kwargs) for sample in samples] + def encode(self, samples, **kwargs): + """Maps a list of string tokens to a list of lists of integer token ids.""" + return [self.vocab.encode(sample, **kwargs) for sample in samples] - def decode(self, samples, **kwargs): - """Maps list of lists of integer token ids to list of strings.""" - return [self.vocab.decode(sample, **kwargs) for sample in samples] + def decode(self, samples, **kwargs): + """Maps list of lists of integer token ids to list of strings.""" + return [self.vocab.decode(sample, **kwargs) for sample in samples] @gin.configurable class FixedLengthDiscreteDomain(DiscreteDomain): - """Output is a fixed length discrete sequence.""" - - def __init__(self, vocab_size=None, length=None, vocab=None): - """Creates an instance of this class. - - Args: - vocab_size: An optional integer for constructing a vocab of this size. - If provided, `vocab` must be `None`. - length: The length of the domain (required). - vocab: The `Vocabulary` of the domain. If provided, `vocab_size` must be - `None`. - - Raises: - ValueError: If neither `vocab_size` nor `vocab` is provided. - ValueError: If `length` if not provided. 
- """ - if length is None: - raise ValueError('length must be provided!') - if not (vocab_size is None) ^ (vocab is None): - raise ValueError('Exactly one of vocab_size of vocab must be specified!') - self._length = length - if vocab is None: - vocab = Vocabulary(vocab_size) - super(FixedLengthDiscreteDomain, self).__init__(vocab) - - @property - def length(self): - return self._length - - @property - def size(self): - """The number of structures in the Domain.""" - return self.vocab_size**self.length - - def is_valid(self, sequence): - return len(sequence) == self.length and self.vocab.are_valid(sequence).all() - - def sample_uniformly(self, num_samples, seed=None): - random_state = utils.get_random_state(seed) - return np.int32( - random_state.randint( - size=[num_samples, self.length], low=0, high=self.vocab_size)) - - def index_to_structure(self, index): - """Given an integer and target length, encode into structure.""" - structure = np.zeros(self.length, dtype=np.int32) - tokens = [int(token, base=len(self.vocab)) - for token in np.base_repr(index, base=len(self.vocab))] - structure[-len(tokens):] = tokens - return structure - - def structure_to_index(self, structure): - """Returns the index of a sequence over a vocabulary of size `vocab_size`.""" - structure = np.asarray(structure)[::-1] - return np.sum(structure * np.power(len(self.vocab), range(len(structure)))) \ No newline at end of file + """Output is a fixed length discrete sequence.""" + + def __init__(self, vocab_size=None, length=None, vocab=None): + """Creates an instance of this class. + + Args: + vocab_size: An optional integer for constructing a vocab of this size. + If provided, `vocab` must be `None`. + length: The length of the domain (required). + vocab: The `Vocabulary` of the domain. If provided, `vocab_size` must be + `None`. + + Raises: + ValueError: If neither `vocab_size` nor `vocab` is provided. + ValueError: If `length` if not provided. 
+ """ + if length is None: + raise ValueError("length must be provided!") + if not (vocab_size is None) ^ (vocab is None): + raise ValueError("Exactly one of vocab_size of vocab must be specified!") + self._length = length + if vocab is None: + vocab = Vocabulary(vocab_size) + super(FixedLengthDiscreteDomain, self).__init__(vocab) + + @property + def length(self): + return self._length + + @property + def size(self): + """The number of structures in the Domain.""" + return self.vocab_size**self.length + + def is_valid(self, sequence): + return len(sequence) == self.length and self.vocab.are_valid(sequence).all() + + def sample_uniformly(self, num_samples, seed=None): + random_state = utils.get_random_state(seed) + return np.int32( + random_state.randint( + size=[num_samples, self.length], low=0, high=self.vocab_size + ) + ) + + def index_to_structure(self, index): + """Given an integer and target length, encode into structure.""" + structure = np.zeros(self.length, dtype=np.int32) + tokens = [ + int(token, base=len(self.vocab)) + for token in np.base_repr(index, base=len(self.vocab)) + ] + structure[-len(tokens) :] = tokens + return structure + + def structure_to_index(self, structure): + """Returns the index of a sequence over a vocabulary of size `vocab_size`.""" + structure = np.asarray(structure)[::-1] + return np.sum(structure * np.power(len(self.vocab), range(len(structure)))) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py index 22a29a7..af76f4a 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/utils.py @@ -38,327 +38,344 @@ # TODO(ddohan): FrozenConfig type class Config(dict): - """a dictionary that supports dot and dict notation. + """a dictionary that supports dot and dict notation. 
- Create: - d = Config() - d = Config({'val1':'first'}) + Create: + d = Config() + d = Config({'val1':'first'}) - Get: - d.val2 - d['val2'] + Get: + d.val2 + d['val2'] - Set: - d.val2 = 'second' - d['val2'] = 'second' - """ - __getattr__ = dict.__getitem__ - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ + Set: + d.val2 = 'second' + d['val2'] = 'second' + """ - def __str__(self): - return pprint.pformat(self) + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ - def __deepcopy__(self, memo): - return self.__class__([(copy.deepcopy(k, memo), copy.deepcopy(v, memo)) - for k, v in self.items()]) + def __str__(self): + return pprint.pformat(self) + def __deepcopy__(self, memo): + return self.__class__( + [(copy.deepcopy(k, memo), copy.deepcopy(v, memo)) for k, v in self.items()] + ) -def get_logger(name='', with_absl=True, level=logging.INFO): - """Creates a logger.""" - logger = std_logging.getLogger(name) - if with_absl: - logger.addHandler(logging.get_absl_handler()) - logger.propagate = False - logger.setLevel(level) - return logger + +def get_logger(name="", with_absl=True, level=logging.INFO): + """Creates a logger.""" + logger = std_logging.getLogger(name) + if with_absl: + logger.addHandler(logging.get_absl_handler()) + logger.propagate = False + logger.setLevel(level) + return logger def get_random_state(seed_or_state): - """Returns a np.random.RandomState given an integer seed or RandomState.""" - if isinstance(seed_or_state, int): - return np.random.RandomState(seed_or_state) - elif seed_or_state is None: - # This returns the current global np random state. - return np.random.random.__self__ - elif not isinstance(seed_or_state, np.random.RandomState): - raise ValueError('Numpy RandomState or integer seed expected! Got: %s' % - seed_or_state) - else: - return seed_or_state + """Returns a np.random.RandomState given an integer seed or RandomState.""" + if isinstance(seed_or_state, int): + return np.random.RandomState(seed_or_state) + elif seed_or_state is None: + # This returns the current global np random state. + return np.random.random.__self__ + elif not isinstance(seed_or_state, np.random.RandomState): + raise ValueError( + "Numpy RandomState or integer seed expected! 
Got: %s" % seed_or_state + ) + else: + return seed_or_state def set_seed(seed): - """Sets global Numpy, Tensorboard, and random seed.""" - np.random.seed(seed) - tf.set_random_seed(seed) - random.seed(seed, version=1) + """Sets global Numpy, Tensorboard, and random seed.""" + np.random.seed(seed) + tf.set_random_seed(seed) + random.seed(seed, version=1) def to_list(values, none_to_list=True): - """Converts `values` of any type to a `list`.""" - if (hasattr(values, '__iter__') and not isinstance(values, str) and - not isinstance(values, dict)): - return list(values) - elif none_to_list and values is None: - return [] - else: - return [values] + """Converts `values` of any type to a `list`.""" + if ( + hasattr(values, "__iter__") + and not isinstance(values, str) + and not isinstance(values, dict) + ): + return list(values) + elif none_to_list and values is None: + return [] + else: + return [values] def to_array(values): - """Converts input values to a np.ndarray.""" - if tf.executing_eagerly() and tf.is_tensor(values): - return values.numpy() - else: - return np.asarray(values) + """Converts input values to a np.ndarray.""" + if tf.executing_eagerly() and tf.is_tensor(values): + return values.numpy() + else: + return np.asarray(values) def arrays_from_dataset(dataset): - """Converts a tf.data.Dataset to nested np.ndarrays.""" - return tf.nest.map_structure( - lambda tensor: np.asarray(tensor), # pylint: disable=unnecessary-lambda - tensors_from_dataset(dataset)) + """Converts a tf.data.Dataset to nested np.ndarrays.""" + return tf.nest.map_structure( + lambda tensor: np.asarray(tensor), # pylint: disable=unnecessary-lambda + tensors_from_dataset(dataset), + ) def dataset_from_tensors(tensors): - """Converts nested tf.Tensors or np.ndarrays to a tf.Data.Dataset.""" - if isinstance(tensors, types.GeneratorType) or isinstance(tensors, list): - tensors = tuple(tensors) - return tf.data.Dataset.from_tensor_slices(tensors) + """Converts nested tf.Tensors or np.ndarrays to a tf.Data.Dataset.""" + if isinstance(tensors, types.GeneratorType) or isinstance(tensors, list): + tensors = tuple(tensors) + return tf.data.Dataset.from_tensor_slices(tensors) def random_choice(values, size=None, random_state=None, **kwargs): - """Enables safer sampling from a list of values than `np.random.choice`. + """Enables safer sampling from a list of values than `np.random.choice`. - `np.random.choice` fails when trying to sample, e.g., from a list of + `np.random.choice` fails when trying to sample, e.g., from a list of - indices instead of sampling from `values` directly. + indices instead of sampling from `values` directly. - Args: - values: An iterable of values. - size: The sample size. - random_state: An integer seed or a `np.random.RandomState`. - **kwargs: Named arguments passed to `np.random.choice`. + Args: + values: An iterable of values. + size: The sample size. + random_state: An integer seed or a `np.random.RandomState`. + **kwargs: Named arguments passed to `np.random.choice`. - Returns: - As single element from `values` if `size is None`, otherwise a list of - samples from `values` of length `size`. 
- """ - random_state = get_random_state(random_state) - values = list(values) - effective_size = 1 if size is None else size - idxs = random_state.choice(range(len(values)), size=effective_size, **kwargs) - samples = [values[idx] for idx in idxs] - return samples[0] if size is None else samples + Returns: + As single element from `values` if `size is None`, otherwise a list of + samples from `values` of length `size`. + """ + random_state = get_random_state(random_state) + values = list(values) + effective_size = 1 if size is None else size + idxs = random_state.choice(range(len(values)), size=effective_size, **kwargs) + samples = [values[idx] for idx in idxs] + return samples[0] if size is None else samples def random_shuffle(values, random_state=None): - """Shuffles a list of `values` out-of-place.""" - return random_choice( - values, size=len(values), replace=False, random_state=random_state) + """Shuffles a list of `values` out-of-place.""" + return random_choice( + values, size=len(values), replace=False, random_state=random_state + ) def get_tokens(sequences, lower=False, upper=False): - """Returns a sorted list of all unique characters of a list of sequences. + """Returns a sorted list of all unique characters of a list of sequences. - Args: - sequences: An iterable of string sequences. - lower: Whether to lower-case sequences before computing tokens. - upper: Whether to upper-case sequences before computing tokens. + Args: + sequences: An iterable of string sequences. + lower: Whether to lower-case sequences before computing tokens. + upper: Whether to upper-case sequences before computing tokens. - Returns: - A sorted list of all characters that appear in `sequences`. - """ - if lower and upper: - raise ValueError('lower and upper must not be specified at the same time!') - if lower: - sequences = [seq.lower() for seq in sequences] - if upper: - sequences = [seq.upper() for seq in sequences] - return sorted(set.union(*[set(seq) for seq in sequences])) + Returns: + A sorted list of all characters that appear in `sequences`. + """ + if lower and upper: + raise ValueError("lower and upper must not be specified at the same time!") + if lower: + sequences = [seq.lower() for seq in sequences] + if upper: + sequences = [seq.upper() for seq in sequences] + return sorted(set.union(*[set(seq) for seq in sequences])) def tensors_from_dataset(dataset): - """Converts a tf.data.Dataset to nested tf.Tensors.""" - tensors = list(dataset) - if tensors: - return tf.nest.map_structure(lambda *tensors: tf.stack(tensors), *tensors) - # Return empty tensors if the dataset is empty. - shapes_dtypes = zip( - tf.nest.flatten(dataset.output_shapes), - tf.nest.flatten(dataset.output_types)) - tensors = [ - tf.zeros(shape=[0] + shape.as_list(), dtype=dtype) - for shape, dtype in shapes_dtypes - ] - return tf.nest.pack_sequence_as(dataset.output_shapes, tensors) + """Converts a tf.data.Dataset to nested tf.Tensors.""" + tensors = list(dataset) + if tensors: + return tf.nest.map_structure(lambda *tensors: tf.stack(tensors), *tensors) + # Return empty tensors if the dataset is empty. + shapes_dtypes = zip( + tf.nest.flatten(dataset.output_shapes), tf.nest.flatten(dataset.output_types) + ) + tensors = [ + tf.zeros(shape=[0] + shape.as_list(), dtype=dtype) + for shape, dtype in shapes_dtypes + ] + return tf.nest.pack_sequence_as(dataset.output_shapes, tensors) def hash_structure(structure): - """Hashes a structure (n-d numpy array) of either ints or floats to a string. 
- - Args: - structure: A structure of ints that is castable to np.int32 (examples - include an np.int32, np.int64, an int32 eager tf.Tensor, - a list of python ints, etc.) or a structure of floats that is castable - to np.float32. Here, we say that an array is castable if it can be - converted to the target type, perhaps with some loss of precision - (e.g. float64 to float32). See np.can_cast(..., casting='same_kind'). - - Returns: - A string hash for the structure. The hash will depend on the - high-level type (int vs. float), but not the precision of such a type - (int32 vs. int64). - """ - array = np.asarray(structure) - if np.can_cast(array, np.int32, 'same_kind'): - return np.int32(array).tostring() - elif np.can_cast(array, np.float32, 'same_kind'): - return np.float32(array).tostring() - raise ValueError('%s can not be safely cast to np.int32 or ' - 'np.float32' % str(structure)) + """Hashes a structure (n-d numpy array) of either ints or floats to a string. + + Args: + structure: A structure of ints that is castable to np.int32 (examples + include an np.int32, np.int64, an int32 eager tf.Tensor, + a list of python ints, etc.) or a structure of floats that is castable + to np.float32. Here, we say that an array is castable if it can be + converted to the target type, perhaps with some loss of precision + (e.g. float64 to float32). See np.can_cast(..., casting='same_kind'). + + Returns: + A string hash for the structure. The hash will depend on the + high-level type (int vs. float), but not the precision of such a type + (int32 vs. int64). + """ + array = np.asarray(structure) + if np.can_cast(array, np.int32, "same_kind"): + return np.int32(array).tostring() + elif np.can_cast(array, np.float32, "same_kind"): + return np.float32(array).tostring() + raise ValueError( + "%s can not be safely cast to np.int32 or " "np.float32" % str(structure) + ) def create_unique_id(): - """Creates a unique hex ID.""" - return uuid.uuid1().hex + """Creates a unique hex ID.""" + return uuid.uuid1().hex def deduplicate_samples(samples, select_best=False): - """De-duplicates Samples with identical structures. + """De-duplicates Samples with identical structures. - Args: - samples: An iterable of `data.Sample`s. - select_best: Whether to select the sample with the highest reward among - samples with the same structure. Otherwise, the sample that occurs - first will be selected. + Args: + samples: An iterable of `data.Sample`s. + select_best: Whether to select the sample with the highest reward among + samples with the same structure. Otherwise, the sample that occurs + first will be selected. - Returns: - A list of Samples. - """ + Returns: + A list of Samples. 
+ """ - def _sort(to_sort): - return (sorted(to_sort, key=lambda sample: sample.reward, reverse=True) - if select_best else to_sort) + def _sort(to_sort): + return ( + sorted(to_sort, key=lambda sample: sample.reward, reverse=True) + if select_best + else to_sort + ) - return [_sort(group)[0] for group in group_samples(samples).values()] + return [_sort(group)[0] for group in group_samples(samples).values()] def group_samples(samples, **kwargs): - """Groups `data.Sample`s with identical structures using `group_by_hash`.""" - return group_by_hash( - samples, - hash_fn=lambda sample: hash_structure(sample.structure), - **kwargs) + """Groups `data.Sample`s with identical structures using `group_by_hash`.""" + return group_by_hash( + samples, hash_fn=lambda sample: hash_structure(sample.structure), **kwargs + ) def group_by_hash(values, hash_fn=hash, store_index=False): - """Groups values by their hash value. - - Args: - values: An iterable of any objects. - hash_fn: A function that is called to compute the hash of values. - store_index: Whether to store the index of values or values in the returned - dict. - - Returns: - A `collections.OrderedDict` mapping hashes to values with that hash if - `store_index=False`, or value indices otherwise. The length of the map - corresponds to the number of unique hashes. - """ - groups = collections.OrderedDict() - for idx, value in enumerate(values): - to_store = idx if store_index else value - groups.setdefault(hash_fn(value), []).append(to_store) - return groups + """Groups values by their hash value. + + Args: + values: An iterable of any objects. + hash_fn: A function that is called to compute the hash of values. + store_index: Whether to store the index of values or values in the returned + dict. + + Returns: + A `collections.OrderedDict` mapping hashes to values with that hash if + `store_index=False`, or value indices otherwise. The length of the map + corresponds to the number of unique hashes. + """ + groups = collections.OrderedDict() + for idx, value in enumerate(values): + to_store = idx if store_index else value + groups.setdefault(hash_fn(value), []).append(to_store) + return groups def get_instance(instance_or_cls, **kwargs): - """Returns an instance given an instance or class reference. - - Enables passing both class references and class instances as (gin) configs. - - Args: - instance_or_cls: An instance of class or reference to a class. - **kwargs: Names arguments used for instantiation if `instance_or_cls` is a - class. - - Returns: - An instance of a class. - """ - if (inspect.isclass(instance_or_cls) or inspect.isfunction(instance_or_cls) or - isinstance(instance_or_cls, functools.partial)): - return instance_or_cls(**kwargs) - else: - return instance_or_cls - - -def pd_option_context(width=999, - max_colwidth=999, - max_rows=200, - float_format='{:.3g}', - **kwargs): - """Returns a Pandas context manager with changed default arguments.""" - return pd.option_context('display.width', width, - 'display.max_colwidth', max_colwidth, - 'display.max_rows', max_rows, - 'display.float_format', float_format.format, - **kwargs) + """Returns an instance given an instance or class reference. + + Enables passing both class references and class instances as (gin) configs. + + Args: + instance_or_cls: An instance of class or reference to a class. + **kwargs: Names arguments used for instantiation if `instance_or_cls` is a + class. + + Returns: + An instance of a class. 
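+
+  Example (illustrative):
+    get_instance(dict) returns {}             # a class reference is called
+    get_instance({'a': 1}) returns {'a': 1}   # an instance is passed through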
+ """ + if ( + inspect.isclass(instance_or_cls) + or inspect.isfunction(instance_or_cls) + or isinstance(instance_or_cls, functools.partial) + ): + return instance_or_cls(**kwargs) + else: + return instance_or_cls + + +def pd_option_context( + width=999, max_colwidth=999, max_rows=200, float_format="{:.3g}", **kwargs +): + """Returns a Pandas context manager with changed default arguments.""" + return pd.option_context( + "display.width", + width, + "display.max_colwidth", + max_colwidth, + "display.max_rows", + max_rows, + "display.float_format", + float_format.format, + **kwargs, + ) def log_pandas(df_or_series, logger=logging.info, **kwargs): - """Logs a `pd.DataFrame` or `pd.Series`.""" - with pd_option_context(**kwargs): - for row in df_or_series.to_string().splitlines(): - logger(row) - - -def get_indices(valid_indices, - selection, - map_negative=True, - validate=False, - exclude=False): - """Maps a `selection` to `valid_indices` for indexing an iterable. - - Supports selecting indices as by - - a scalar: it[i] - - a list of scalars: it[[i, j, k]] - - a (list of) negative scalars: it[[i, -j, -k]] - - slices: it[slice(-3, None)] - - Args: - valid_indices: An iterable of valid indices that can be selected. - selection: A scalar, list of scalars, or `slice` for selecting indices. - map_negative: Whether to interpret `-i` as `len(valid_indices) + i`. - validate: Whether to raise an `IndexError` if `selection` is not - contained in `valid_indices`. - exclude: Whether to return all `valid_indices` except the selected ones. - - Raises: - IndexError: If `validate == True` and `selection` contains an index that is - not contained in `valid_indices`. - - Returns: - A list of indices. - """ - def _raise_index_error(idx): - raise IndexError(f'Index {idx} invalid! Valid indices: {valid_indices}') - - if isinstance(selection, slice): - idxs = valid_indices[selection] - else: - idxs = [] - for idx in to_list(selection): - if map_negative and isinstance(idx, int) and idx < 0: - if abs(idx) <= len(valid_indices): - idxs.append(valid_indices[idx]) - elif validate: - _raise_index_error(idx) - elif idx in valid_indices: - idxs.append(idx) - elif validate: - _raise_index_error(idx) - if exclude: - idxs = [idx for idx in valid_indices if idx not in idxs] - return idxs \ No newline at end of file + """Logs a `pd.DataFrame` or `pd.Series`.""" + with pd_option_context(**kwargs): + for row in df_or_series.to_string().splitlines(): + logger(row) + + +def get_indices( + valid_indices, selection, map_negative=True, validate=False, exclude=False +): + """Maps a `selection` to `valid_indices` for indexing an iterable. + + Supports selecting indices as by + - a scalar: it[i] + - a list of scalars: it[[i, j, k]] + - a (list of) negative scalars: it[[i, -j, -k]] + - slices: it[slice(-3, None)] + + Args: + valid_indices: An iterable of valid indices that can be selected. + selection: A scalar, list of scalars, or `slice` for selecting indices. + map_negative: Whether to interpret `-i` as `len(valid_indices) + i`. + validate: Whether to raise an `IndexError` if `selection` is not + contained in `valid_indices`. + exclude: Whether to return all `valid_indices` except the selected ones. + + Raises: + IndexError: If `validate == True` and `selection` contains an index that is + not contained in `valid_indices`. + + Returns: + A list of indices. + """ + + def _raise_index_error(idx): + raise IndexError(f"Index {idx} invalid! 
Valid indices: {valid_indices}") + + if isinstance(selection, slice): + idxs = valid_indices[selection] + else: + idxs = [] + for idx in to_list(selection): + if map_negative and isinstance(idx, int) and idx < 0: + if abs(idx) <= len(valid_indices): + idxs.append(valid_indices[idx]) + elif validate: + _raise_index_error(idx) + elif idx in valid_indices: + idxs.append(idx) + elif validate: + _raise_index_error(idx) + if exclude: + idxs = [idx for idx in valid_indices if idx not in idxs] + return idxs diff --git a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py index 2f85b9d..6855f75 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py +++ b/src/poli_baselines/solvers/bayesian_optimization/turbo/turbo_wrapper.py @@ -51,12 +51,12 @@ def __init__( The first row contains the lower bounds on x, the last row contains the upper bounds. """ super().__init__(black_box, x0, y0) - - assert(x0.shape[0] > 1) - assert(bounds.shape[1] == 2) - assert(bounds.shape[0] == x0.shape[1]) - assert(np.all(bounds[:, 1] >= bounds[:, 0])) + assert x0.shape[0] > 1 + + assert bounds.shape[1] == 2 + assert bounds.shape[0] == x0.shape[1] + assert np.all(bounds[:, 1] >= bounds[:, 0]) bounds[:, 1] -= bounds[:, 0] assert x0.shape[0] > 1 diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py index 1a1ddbf..ca0defd 100644 --- a/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py +++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_amortized_bo.py @@ -3,7 +3,9 @@ import warnings import numpy as np -from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import AmortizedBOWrapper +from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import ( + AmortizedBOWrapper, +) warnings.filterwarnings("ignore") @@ -11,9 +13,7 @@ def test_amortized_bo_runs(): from poli import objective_factory - problem = objective_factory.create( - name="aloha", observer_name=None - ) + problem = objective_factory.create(name="aloha", observer_name=None) black_box, x0 = problem.black_box, problem.x0 y0 = black_box(x0) diff --git a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py index 652a700..43db1c5 100644 --- a/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py +++ b/src/poli_baselines/tests/solvers/bayesian_optimization/test_turbo.py @@ -21,7 +21,9 @@ def test_turbo_runs(): x0 = np.concatenate([x0, np.random.rand(1, x0.shape[1])]) y0 = black_box(x0) - bounds = np.concatenate([-np.ones([x0.shape[1], 1]), np.ones([x0.shape[1], 1])], axis=-1) + bounds = np.concatenate( + [-np.ones([x0.shape[1], 1]), np.ones([x0.shape[1], 1])], axis=-1 + ) solver = Turbo(black_box, x0, y0, bounds=bounds) From b2b3cbf15a7b823f4e88da7d2ec8ac3f22a09614 Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Sat, 10 Aug 2024 15:24:17 +0200 Subject: [PATCH 7/9] adapts interface to changes in main and fixes a problem with unaligned sequences --- .../amortized/amortized_bo_wrapper.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py index df3f539..b047aa8 
100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py @@ -1,7 +1,7 @@ import numpy as np from poli.core.abstract_black_box import AbstractBlackBox -from poli_baselines.core.abstract_solver import AbstractSolver +from poli_baselines.core.step_by_step_solver import StepByStepSolver from poli_baselines.solvers.bayesian_optimization.amortized.data import ( samples_from_arrays, Population, @@ -15,11 +15,15 @@ ) -class AmortizedBOWrapper(AbstractSolver): +class AmortizedBOWrapper(StepByStepSolver): def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray): super().__init__(black_box, x0, y0) + self.problem_info = black_box.get_black_box_info() + alphabet = self.problem_info.get_alphabet() + if not self.problem_info.sequences_are_aligned(): + alphabet = alphabet + [self.problem_info.get_padding_token()] self.domain = FixedLengthDiscreteDomain( - vocab=Vocabulary(black_box.get_black_box_info().get_alphabet()), + vocab=Vocabulary(alphabet), length=x0.shape[1], ) self.solver = MutationPredictorSolver( @@ -32,4 +36,7 @@ def next_candidate(self) -> np.ndarray: structures=self.domain.encode(self.x0.tolist()), rewards=self.y0.tolist() ) x = self.solver.propose(num_samples=1, population=Population(samples)) - return np.array([c for c in self.domain.decode(x)[0]])[np.newaxis, :] + s = list(self.domain.decode(x)[0]) + if not self.problem_info.sequences_are_aligned(): + s = s + [self.problem_info.get_padding_token()] * (self.problem_info.get_max_sequence_length() - len(s)) + return np.array([s]) From f2aee31c767ed4d29201f1d02771c0317176e356 Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Sat, 10 Aug 2024 15:24:39 +0200 Subject: [PATCH 8/9] fixes a problem with odd alphabet size --- .../bayesian_optimization/amortized/deep_evolution_solver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py index dfc19fe..590e921 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py @@ -273,7 +273,9 @@ def _create_positional_encoding( # pylint: disable=invalid-name position = np.arange(0, max_len)[:, np.newaxis] div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) pe[:, 0::2] = np.sin(position * div_term) - pe[:, 1::2] = np.cos(position * div_term) + # FIXME: Simon: the line below does not work for odd alphabet sizes! Get in touch with authors! + #pe[:, 1::2] = np.cos(position * div_term) + pe[:, 1::2] = np.cos(position * div_term[d_feature%2:]) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) # These are trainable parameters, initialized as above. 
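The slice `div_term[d_feature % 2:]` in the hunk above exists because `np.arange(0, d_feature, 2)` yields one more frequency than there are odd-indexed columns whenever `d_feature` is odd. Below is a minimal sketch of the failure and of the repaired assignment; it is not part of the patch, and the toy dimensions (max_len=4, d_feature=5) are made up purely to trigger the odd case.

import numpy as np

max_len, d_feature = 4, 5  # toy values; d_feature is odd on purpose
pe = np.zeros((max_len, d_feature))
position = np.arange(0, max_len)[:, np.newaxis]                                   # shape (4, 1)
div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))    # shape (3,)

pe[:, 0::2] = np.sin(position * div_term)          # (4, 3) into the 3 even columns: fine
try:
    pe[:, 1::2] = np.cos(position * div_term)      # (4, 3) into only 2 odd columns: fails
except ValueError as err:
    print("odd d_feature breaks the original line:", err)

pe[:, 1::2] = np.cos(position * div_term[d_feature % 2:])   # drop one frequency: (4, 2) matches

Dropping the leading frequency is what the patch does and it restores matching shapes; an equally shape-correct alternative would be to truncate the trailing frequency instead (e.g. `div_term[: d_feature // 2]`), which would keep each sin/cos pair on the same frequency as in the standard transformer positional encoding.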
From 9c65be75c253c2165a1b8eabb7fc314b0406a07e Mon Sep 17 00:00:00 2001 From: Simon Bartels Date: Sat, 10 Aug 2024 15:34:07 +0200 Subject: [PATCH 9/9] lints offending files --- .../bayesian_optimization/amortized/amortized_bo_wrapper.py | 4 +++- .../bayesian_optimization/amortized/deep_evolution_solver.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py index b047aa8..985d852 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/amortized_bo_wrapper.py @@ -38,5 +38,7 @@ def next_candidate(self) -> np.ndarray: x = self.solver.propose(num_samples=1, population=Population(samples)) s = list(self.domain.decode(x)[0]) if not self.problem_info.sequences_are_aligned(): - s = s + [self.problem_info.get_padding_token()] * (self.problem_info.get_max_sequence_length() - len(s)) + s = s + [self.problem_info.get_padding_token()] * ( + self.problem_info.get_max_sequence_length() - len(s) + ) return np.array([s]) diff --git a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py index 590e921..54d7ae2 100644 --- a/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py +++ b/src/poli_baselines/solvers/bayesian_optimization/amortized/deep_evolution_solver.py @@ -274,8 +274,8 @@ def _create_positional_encoding( # pylint: disable=invalid-name div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) pe[:, 0::2] = np.sin(position * div_term) # FIXME: Simon: the line below does not work for odd alphabet sizes! Get in touch with authors! - #pe[:, 1::2] = np.cos(position * div_term) - pe[:, 1::2] = np.cos(position * div_term[d_feature%2:]) + # pe[:, 1::2] = np.cos(position * div_term) + pe[:, 1::2] = np.cos(position * div_term[d_feature % 2 :]) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) # These are trainable parameters, initialized as above.
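For reference, a minimal end-to-end sketch of how the patched wrapper is exercised, following the imports and the `objective_factory.create(name="aloha", observer_name=None)` call from test_amortized_bo.py. The extra evaluation of the proposed candidate at the end is illustrative only and not taken from the test.

from poli import objective_factory

from poli_baselines.solvers.bayesian_optimization.amortized.amortized_bo_wrapper import (
    AmortizedBOWrapper,
)

problem = objective_factory.create(name="aloha", observer_name=None)
black_box, x0 = problem.black_box, problem.x0
y0 = black_box(x0)

# Builds the vocabulary from the black box's alphabet; patch 7 appends the
# padding token when the problem's sequences are not aligned.
solver = AmortizedBOWrapper(black_box, x0, y0)

# Proposes one candidate as a (1, sequence_length) token array, padded with the
# problem's padding token up to the maximum sequence length if needed.
x_next = solver.next_candidate()
y_next = black_box(x_next)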