Skip to content

Commit b7cd619

Browse files
1b15lenasal
andcommitted
new oc jax cost functions
Co-authored-by: Lena Salfenmoser <lenasal@users.noreply.github.com>
1 parent c40688f commit b7cd619

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

neurolib/control/optimal_control/oc_jax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def accuracy_cost(self, output):
9292
accuracy_cost += self.weights["w_p"] * 0.5 * self.model.params.dt * jnp.sum((output - self.target) ** 2)
9393
if self.weights["w_cc"] != 0.0:
9494
accuracy_cost += self.weights["w_cc"] * self.compute_cc_cost(output)
95+
if self.weights["w_var"] != 0.0:
96+
accuracy_cost += self.weights["w_var"] * self.compute_var_cost(output)
97+
if self.weights["w_f_osc"] != 0.0:
98+
accuracy_cost += self.weights["w_f_osc"] * self.compute_osc_fourier_cost(output)
9599
return accuracy_cost
96100

97101
def control_strength_cost(self, control):
@@ -118,6 +122,11 @@ def compute_cc_cost(self, output):
118122
cost *= -2.0 / (self.model.params.N * (self.model.params.N - 1) * self.T * self.model.params.dt)
119123
return cost
120124

125+
def compute_var_cost(self, output):
126+
return jnp.var(output, axis=(0, 1)).mean()
127+
128+
def compute_osc_fourier_cost(self, output):
129+
121130
def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
122131
"""Compute the optimal control signal for noise averaging method 0.
123132

0 commit comments

Comments
 (0)