Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,50 @@ def score_with_aux(
# Vizier library acquisition functions use `flax.struct`, instead of `attrs` and
# a hash function, so that acquisition functions can be passed as args to JIT-ed
# functions without triggering retracing when attribute values change.
@struct.dataclass
class MEANPENALIZED(AcquisitionFunction):
"""Mean with relative region AcquisitionFunction."""

coefficient: float = 1.8
max_lcb_thresh: float = -jnp.inf

def __call__(
self,
dist: tfd.Distribution,
seed: Optional[jax.random.KeyArray] = None,
) -> jax.Array:
del seed
ucb = dist.mean() + self.coefficient * dist.stddev()
acquisition = jnp.where(
(ucb >= self.max_lcb_thresh),
dist.mean(),
-1e12 - ucb,
)
return acquisition


@struct.dataclass
class STDPENALIZED(AcquisitionFunction):
"""Standard Deviation with relative region AcquisitionFunction."""

coefficient: float = 1.8
max_lcb_thresh: float = -jnp.inf

def __call__(
self,
dist: tfd.Distribution,
seed: Optional[jax.random.KeyArray] = None,
) -> jax.Array:
del seed
ucb = dist.mean() + self.coefficient * dist.stddev()
acquisition = jnp.where(
(ucb >= self.max_lcb_thresh),
dist.stddev(),
-1e12 - ucb,
)
return acquisition


@struct.dataclass
class UCB(AcquisitionFunction):
"""UCB AcquisitionFunction."""
Expand Down