Skip to content

Commit a52af67

Browse files
SetarehArcopybara-github
authored andcommitted
extra paretopt fixes
PiperOrigin-RevId: 567627034
1 parent cf042eb commit a52af67

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

vizier/_src/algorithms/designers/gp/acquisitions.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,50 @@ def score_with_aux(
170170
# Vizier library acquisition functions use `flax.struct`, instead of `attrs` and
171171
# a hash function, so that acquisition functions can be passed as args to JIT-ed
172172
# functions without triggering retracing when attribute values change.
173+
@struct.dataclass
174+
class MEANPENALIZED(AcquisitionFunction):
175+
"""Mean with relative region AcquisitionFunction."""
176+
177+
coefficient: float = 1.8
178+
max_lcb_thresh: float = -jnp.inf
179+
180+
def __call__(
181+
self,
182+
dist: tfd.Distribution,
183+
seed: Optional[jax.random.KeyArray] = None,
184+
) -> jax.Array:
185+
del seed
186+
ucb = dist.mean() + self.coefficient * dist.stddev()
187+
acquisition = jnp.where(
188+
(ucb >= self.max_lcb_thresh),
189+
dist.mean(),
190+
-1e12 - ucb,
191+
)
192+
return acquisition
193+
194+
195+
@struct.dataclass
196+
class STDPENALIZED(AcquisitionFunction):
197+
"""Standard Deviation with relative region AcquisitionFunction."""
198+
199+
coefficient: float = 1.8
200+
max_lcb_thresh: float = -jnp.inf
201+
202+
def __call__(
203+
self,
204+
dist: tfd.Distribution,
205+
seed: Optional[jax.random.KeyArray] = None,
206+
) -> jax.Array:
207+
del seed
208+
ucb = dist.mean() + self.coefficient * dist.stddev()
209+
acquisition = jnp.where(
210+
(ucb >= self.max_lcb_thresh),
211+
dist.stddev(),
212+
-1e12 - ucb,
213+
)
214+
return acquisition
215+
216+
173217
@struct.dataclass
174218
class UCB(AcquisitionFunction):
175219
"""UCB AcquisitionFunction."""

0 commit comments

Comments
 (0)