@@ -170,6 +170,50 @@ def score_with_aux(
170170# Vizier library acquisition functions use `flax.struct`, instead of `attrs` and
171171# a hash function, so that acquisition functions can be passed as args to JIT-ed
172172# functions without triggering retracing when attribute values change.
173+ @struct .dataclass
174+ class MEANPENALIZED (AcquisitionFunction ):
175+ """Mean with relative region AcquisitionFunction."""
176+
177+ coefficient : float = 1.8
178+ max_lcb_thresh : float = - jnp .inf
179+
180+ def __call__ (
181+ self ,
182+ dist : tfd .Distribution ,
183+ seed : Optional [jax .random .KeyArray ] = None ,
184+ ) -> jax .Array :
185+ del seed
186+ ucb = dist .mean () + self .coefficient * dist .stddev ()
187+ acquisition = jnp .where (
188+ (ucb >= self .max_lcb_thresh ),
189+ dist .mean (),
190+ - 1e12 - ucb ,
191+ )
192+ return acquisition
193+
194+
195+ @struct .dataclass
196+ class STDPENALIZED (AcquisitionFunction ):
197+ """Standard Deviation with relative region AcquisitionFunction."""
198+
199+ coefficient : float = 1.8
200+ max_lcb_thresh : float = - jnp .inf
201+
202+ def __call__ (
203+ self ,
204+ dist : tfd .Distribution ,
205+ seed : Optional [jax .random .KeyArray ] = None ,
206+ ) -> jax .Array :
207+ del seed
208+ ucb = dist .mean () + self .coefficient * dist .stddev ()
209+ acquisition = jnp .where (
210+ (ucb >= self .max_lcb_thresh ),
211+ dist .stddev (),
212+ - 1e12 - ucb ,
213+ )
214+ return acquisition
215+
216+
173217@struct .dataclass
174218class UCB (AcquisitionFunction ):
175219 """UCB AcquisitionFunction."""
0 commit comments