
Commit 8f8b8ce

merging

2 parents: 89af353 + 0a20382


neurolib/control/optimal_control/oc_jax.py

Lines changed: 88 additions & 37 deletions
@@ -5,11 +5,11 @@
 import copy
 from neurolib.models.jax.wc import WCModel
 from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
-from neurolib.optimize.autodiff.wc_optimizer import args_names
 
 import logging
 from neurolib.control.optimal_control.oc import getdefaultweights
 
+# TODO: introduce for all models, not just WC
 wc_default_control_params = ["exc_ext", "inh_ext"]
 wc_default_target_params = ["exc", "inh"]
 
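
With the module-level `args_names` import gone, the integration argument names are now looked up on the model instance itself. A minimal sketch of the resulting pattern, mirroring the constructor change in the next hunk (it assumes a JAX `WCModel` that exposes `args_names`, which is what the new assertions require):

from neurolib.models.jax.wc import WCModel
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args

model = WCModel()
# Pair the model's own argument names with the values consumed by the
# element-wise time integration, as Optimize.__init__ does below.
args_values = timeIntegration_args(model.params)
args = dict(zip(model.args_names, args_values))
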
@@ -35,78 +35,129 @@ def hilbert_jax(signal, axis=-1):
     return analytic_signal
 
 
-class OcWc:
+class Optimize:
     def __init__(
         self,
         model,
+        loss_function,
+        param_names,
+        target_param_names,
         target=None,
-        optimizer=optax.adam(1e-3),
-        control_params=wc_default_control_params,
-        target_params=wc_default_target_params,
+        init_params=None,
+        optimizer=optax.adabelief(1e-3),
     ):
-        assert isinstance(control_params, (list, tuple)) and len(control_params) > 0
-        assert isinstance(target_params, (list, tuple)) and len(target_params) > 0
-        assert all([cp in wc_default_control_params for cp in control_params])
-        assert all([tp in wc_default_target_params for tp in target_params])
+        assert isinstance(param_names, (list, tuple)) and len(param_names) > 0
+        assert isinstance(target_param_names, (list, tuple)) and len(target_param_names) > 0
+        assert all([p in model.args_names for p in param_names])
+        assert all([tp in model.output_vars for tp in target_param_names])
 
         self.model = copy.deepcopy(model)
+        self.loss_function = loss_function
         self.target = target
         self.optimizer = optimizer
-        self.control_params = control_params
-        self.target_params = target_params
-
-        self.weights = getdefaultweights()
+        self.param_names = param_names
+        self.target_param_names = target_param_names
 
         args_values = timeIntegration_args(self.model.params)
-        self.args = dict(zip(args_names, args_values))
+        self.args = dict(zip(self.model.args_names, args_values))
 
-        self.loss = self.get_loss()
-        self.compute_gradient = jax.jit(jax.grad(self.loss))
         self.T = len(self.args["t"])
         self.startind = self.model.getMaxDelay()
-        self.control = jnp.zeros((len(control_params), self.model.params.N, self.T), dtype=float)
-        self.opt_state = self.optimizer.init(self.control)
+        if init_params is not None:
+            self.params = init_params
+        else:
+            self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
+        self.opt_state = self.optimizer.init(self.params)
+
+        compute_loss = lambda params: self.loss_function(params, self.get_output(params))
+        self.compute_loss = jax.jit(compute_loss)
+        self.compute_gradient = jax.jit(jax.grad(self.compute_loss))
 
         self.cost_history = []
 
-    def simulate(self, control):
+    # TODO: allow arbitrary model, not just WC
+    def simulate(self, params):
         args_local = self.args.copy()
-        args_local.update(dict(zip(self.control_params, [c for c in control])))
-        return timeIntegration_elementwise(**args_local)
-
-    def get_output(self, control):
-        t, exc, inh, exc_ou, inh_ou = self.simulate(control)
-        if self.target_params == ["exc", "inh"]:
-            output = jnp.stack((exc, inh), axis=0)
-        elif self.target_params == ["exc"]:
-            output = exc[None, ...]
-        elif self.target_params == ["inh"]:
-            output = inh[None, ...]
-        return output[:, :, self.startind :]
+        args_local.update(params)
+        t, exc, inh, exc_ou, inh_ou = timeIntegration_elementwise(**args_local)
+        return {
+            "t": t,
+            "exc": exc,
+            "inh": inh,
+            "exc_ou": exc_ou,
+            "inh_ou": inh_ou,
+        }
+
+    def get_output(self, params):
+        simulation_results = self.simulate(params)
+        return jnp.stack([simulation_results[tp][:, self.startind :] for tp in self.target_param_names])
 
     def get_loss(self):
         @jax.jit
-        def loss(control):
-            output = self.get_output(control)
-            return self.compute_total_cost(control, output)
+        def loss(params):
+            output = self.get_output(params)
+            return self.loss_function(params, output)
 
         return loss
 
+    def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
+        loss = self.compute_loss(self.control)
+        print(f"loss in iteration 0: %s" % (loss))
+        if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
+            self.cost_history.append(loss)
+
+        for i in range(1, n_max_iterations + 1):
+            self.gradient = self.compute_gradient(self.control)
+            updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
+            self.control = optax.apply_updates(self.control, updates)
+
+            if output_every_nth is not None and i % output_every_nth == 0:
+                loss = self.compute_loss(self.control)
+                self.cost_history.append(loss)
+                print(f"loss in iteration %s: %s" % (i, loss))
+
+        loss = self.compute_loss(self.control)
+        print(f"Final loss : %s" % (loss))
+
+
+class OcWc(Optimize):
+    def __init__(
+        self,
+        model,
+        target=None,
+        optimizer=optax.adabelief(1e-3),
+        control_param_names=wc_default_control_params,
+        target_param_names=wc_default_target_params,
+    ):
+        super().__init__(
+            model,
+            self.compute_total_cost,
+            control_param_names,
+            target_param_names,
+            target=target,
+            init_params=None,
+            optimizer=optimizer,
+        )
+        self.control = self.params
+        self.weights = getdefaultweights()
+
     def compute_total_cost(self, control, output):
         """
         Compute the total cost as the sum of accuracy cost and control strength cost.
 
         Parameters:
-            control (jax.numpy.ndarray): Control input array of shape ((len(control_params)), N, T).
-            output (jax.numpy.ndarray): Simulation output of shape ((len(target_params)), N, T).
+            control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
+            output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).
 
         Returns:
             float: The total cost.
         """
         accuracy_cost = self.accuracy_cost(output)
-        control_strength_cost = self.control_strength_cost(control)
+        control_arr = jnp.array(list(control.values()))
+        control_strength_cost = self.control_strength_cost(control_arr)
         return accuracy_cost + control_strength_cost
 
+    # TODO: move cost functions outside
     def accuracy_cost(self, output):
         accuracy_cost = 0.0
         if self.weights["w_p"] != 0.0:
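
Since `Optimize` is now loss-agnostic, any callable with the signature `loss_function(params, output)` can be plugged in. A sketch of a custom mean-squared-error fit against a recorded trace (all names below are illustrative; note that `optimize_deterministic` reads `self.control`, which in this commit only `OcWc` sets, so a direct `Optimize` user would step manually):

import jax.numpy as jnp
import optax
from neurolib.models.jax.wc import WCModel
from neurolib.control.optimal_control.oc_jax import Optimize

model = WCModel()
target_trace = jnp.zeros((1, model.params.N, 100))  # hypothetical recorded activity

def mse_loss(params, output):
    # output: (len(target_param_names), N, T - startind)
    return jnp.mean((output[..., : target_trace.shape[-1]] - target_trace) ** 2)

opt = Optimize(
    model,
    mse_loss,
    param_names=["exc_ext"],
    target_param_names=["exc"],
    optimizer=optax.adabelief(1e-3),
)

# One manual gradient step on the parameter dict (a pytree optax can traverse).
grads = opt.compute_gradient(opt.params)
updates, opt.opt_state = opt.optimizer.update(grads, opt.opt_state)
opt.params = optax.apply_updates(opt.params, updates)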