
Commit d1f38ee

minor fixes to gradient-based optimization
1 parent 5ae2b5c commit d1f38ee

File tree: 2 files changed (+27, -9 lines)


neurolib/control/optimal_control/oc_jax.py

Lines changed: 16 additions & 9 deletions
@@ -42,14 +42,17 @@ def __init__(
         self.args = dict(zip(self.model.args_names, args_values))
 
         self.T = len(self.args["t"])
-        self.startind = self.model.getMaxDelay()
+        self.startind = self.model.getMaxDelay() + 1
         if init_params is not None:
             self.params = init_params
         else:
             self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
         self.opt_state = self.optimizer.init(self.params)
 
-        compute_loss = lambda params: self.loss_function(self.get_output(params)) + self.regularization_function(params)
+        # TODO: instead apply individually to each param
+        compute_loss = lambda params: self.loss_function(
+            jnp.stack(list(self.get_output(params).values()))
+        ) + self.regularization_function(params)
         self.compute_loss = jax.jit(compute_loss)
         self.compute_gradient = jax.jit(jax.grad(self.compute_loss))
 
@@ -70,25 +73,25 @@ def simulate(self, params):
 
     def get_output(self, params):
         simulation_results = self.simulate(params)
-        return jnp.stack([simulation_results[tp][:, self.startind :] for tp in self.target_param_names])
+        return {tp: simulation_results[tp][:, self.startind :] for tp in self.target_param_names}
 
     def optimize(self, n_max_iterations, output_every_nth=None):
-        loss = self.compute_loss(self.control)
+        loss = self.compute_loss(self.params)
         print(f"loss in iteration 0: %s" % (loss))
-        if len(self.cost_history) == 0: # add only if control model has not yet been optimized
+        if len(self.cost_history) == 0: # add only if params have not yet been optimized
             self.cost_history.append(loss)
 
         for i in range(1, n_max_iterations + 1):
-            self.gradient = self.compute_gradient(self.control)
+            self.gradient = self.compute_gradient(self.params)
             updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
-            self.control = optax.apply_updates(self.control, updates)
+            self.params = optax.apply_updates(self.params, updates)
 
             if output_every_nth is not None and i % output_every_nth == 0:
-                loss = self.compute_loss(self.control)
+                loss = self.compute_loss(self.params)
                 self.cost_history.append(loss)
                 print(f"loss in iteration %s: %s" % (i, loss))
 
-        loss = self.compute_loss(self.control)
+        loss = self.compute_loss(self.params)
         print(f"Final loss : %s" % (loss))
 
 
@@ -176,3 +179,7 @@ def control_strength_cost(self, control):
     def compute_ds_cost(self, control):
         eps = 1e-6 # avoid grad(sqrt(0.0))
         return jnp.sum(jnp.sqrt(jnp.sum(control**2, axis=2) * self.model.params.dt + eps))
+
+    def optimize(self, *args, **kwargs):
+        super().optimize(*args, **kwargs)
+        self.control = self.params
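For context, here is a minimal, self-contained sketch of the optimization pattern this commit settles on: the optimized quantities live in a `params` dict rather than in a single `self.control` array, the loss and its gradient are `jax.jit`-compiled, and `optax.apply_updates` advances the dict each iteration. The toy quadratic loss, the `target` dict, and all variable names below are illustrative assumptions, not neurolib code.

```python
# Standalone sketch (assumed toy example, not neurolib code) of the
# params-dict gradient loop that optimize() above now runs.
import jax
import jax.numpy as jnp
import optax

# Hypothetical target and initial parameters; the entries share a shape so
# they can be stacked the way compute_loss stacks the get_output() dict.
target = {"x": jnp.array([1.0, -2.0, 0.5]), "y": jnp.array([0.0, 3.0, 1.0])}
params = {"x": jnp.zeros(3), "y": jnp.zeros(3)}

def loss_fn(params):
    # Mirror the committed compute_loss: stack the dict values into one
    # array before applying the loss.
    stacked = jnp.stack(list(params.values()))
    target_stacked = jnp.stack(list(target.values()))
    return jnp.sum((stacked - target_stacked) ** 2)

optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)
compute_loss = jax.jit(loss_fn)
compute_gradient = jax.jit(jax.grad(loss_fn))

for i in range(1, 201):
    gradient = compute_gradient(params)
    updates, opt_state = optimizer.update(gradient, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 50 == 0:
        print("loss in iteration %s: %s" % (i, compute_loss(params)))
```

The `optimize` wrapper added in the last hunk then mirrors the optimized `self.params` back into `self.control`, presumably so code that still reads the old attribute keeps working.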

neurolib/optimize/loss_functions.py

Lines changed: 11 additions & 0 deletions
@@ -2,6 +2,17 @@
 import jax.numpy as jnp
 
 
+def get_l2_regularization(scale=1.0):
+    def l2_regularization(params):
+        """
+        Args:
+            params (dict[str, jax.numpy.ndarray]): Dictionary of parameters being optimized
+        """
+        return scale * jnp.linalg.norm(jnp.stack(list(params.values())))
+
+    return l2_regularization
+
+
 def variance_loss(output):
     """
     Args:
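Below is a brief usage sketch of the new regularizer factory, assuming the module is importable as `neurolib.optimize.loss_functions` and using a made-up `params` dict; note that `jnp.stack` requires all parameter arrays to share a shape. The returned closure can then be plugged in as the `regularization_function` used by `compute_loss` in `oc_jax.py`.

```python
# Usage sketch for get_l2_regularization; the params dict below is a made-up
# example, and all of its entries must share a shape for jnp.stack to work.
import jax.numpy as jnp
from neurolib.optimize.loss_functions import get_l2_regularization

params = {
    "exc_current": 0.5 * jnp.ones((2, 100)),
    "inh_current": jnp.zeros((2, 100)),
}

reg = get_l2_regularization(scale=0.1)
print(reg(params))  # 0.1 * 2-norm over all stacked parameter values
```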
