import copy
from neurolib.models.jax.wc import WCModel
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
- from neurolib.optimize.autodiff.wc_optimizer import args_names

import logging
from neurolib.control.optimal_control.oc import getdefaultweights

+ # TODO: introduce for all models, not just WC
wc_default_control_params = ["exc_ext", "inh_ext"]
wc_default_target_params = ["exc", "inh"]

@@ -35,78 +35,129 @@ def hilbert_jax(signal, axis=-1):
    return analytic_signal


- class OcWc:
+ class Optimize:
    def __init__(
        self,
        model,
+         loss_function,
+         param_names,
+         target_param_names,
        target=None,
-         optimizer=optax.adam(1e-3),
-         control_params=wc_default_control_params,
-         target_params=wc_default_target_params,
+         init_params=None,
+         optimizer=optax.adabelief(1e-3),
    ):
-         assert isinstance(control_params, (list, tuple)) and len(control_params) > 0
-         assert isinstance(target_params, (list, tuple)) and len(target_params) > 0
-         assert all([cp in wc_default_control_params for cp in control_params])
-         assert all([tp in wc_default_target_params for tp in target_params])
+         assert isinstance(param_names, (list, tuple)) and len(param_names) > 0
+         assert isinstance(target_param_names, (list, tuple)) and len(target_param_names) > 0
+         assert all([p in model.args_names for p in param_names])
+         assert all([tp in model.output_vars for tp in target_param_names])

        self.model = copy.deepcopy(model)
+         self.loss_function = loss_function
        self.target = target
        self.optimizer = optimizer
-         self.control_params = control_params
-         self.target_params = target_params
-
-         self.weights = getdefaultweights()
+         self.param_names = param_names
+         self.target_param_names = target_param_names

        args_values = timeIntegration_args(self.model.params)
-         self.args = dict(zip(args_names, args_values))
+         self.args = dict(zip(self.model.args_names, args_values))

-         self.loss = self.get_loss()
-         self.compute_gradient = jax.jit(jax.grad(self.loss))
        self.T = len(self.args["t"])
        self.startind = self.model.getMaxDelay()
-         self.control = jnp.zeros((len(control_params), self.model.params.N, self.T), dtype=float)
-         self.opt_state = self.optimizer.init(self.control)
+         if init_params is not None:
+             self.params = init_params
+         else:
+             self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
+         self.opt_state = self.optimizer.init(self.params)
+
+         compute_loss = lambda params: self.loss_function(params, self.get_output(params))
+         self.compute_loss = jax.jit(compute_loss)
+         self.compute_gradient = jax.jit(jax.grad(self.compute_loss))

        self.cost_history = []

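Since the optimized parameters are kept as a plain dict of named arrays, jax.grad produces a gradient dict with the same structure, and optax handles that dict as a pytree when initializing and applying updates. A small sketch with a hypothetical single-entry parameter dict:

    params = {"exc_ext": jnp.zeros((2, 100))}                        # hypothetical (N, T) array
    grads = jax.grad(lambda p: jnp.sum(p["exc_ext"] ** 2))(params)   # {"exc_ext": (2, 100) array}
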
-     def simulate(self, control):
+     # TODO: allow arbitrary model, not just WC
+     def simulate(self, params):
        args_local = self.args.copy()
-         args_local.update(dict(zip(self.control_params, [c for c in control])))
-         return timeIntegration_elementwise(**args_local)
-
-     def get_output(self, control):
-         t, exc, inh, exc_ou, inh_ou = self.simulate(control)
-         if self.target_params == ["exc", "inh"]:
-             output = jnp.stack((exc, inh), axis=0)
-         elif self.target_params == ["exc"]:
-             output = exc[None, ...]
-         elif self.target_params == ["inh"]:
-             output = inh[None, ...]
-         return output[:, :, self.startind:]
+         args_local.update(params)
+         t, exc, inh, exc_ou, inh_ou = timeIntegration_elementwise(**args_local)
+         return {
+             "t": t,
+             "exc": exc,
+             "inh": inh,
+             "exc_ou": exc_ou,
+             "inh_ou": inh_ou,
+         }
+
+     def get_output(self, params):
+         simulation_results = self.simulate(params)
+         return jnp.stack([simulation_results[tp][:, self.startind:] for tp in self.target_param_names])

    def get_loss(self):
        @jax.jit
-         def loss(control):
-             output = self.get_output(control)
-             return self.compute_total_cost(control, output)
+         def loss(params):
+             output = self.get_output(params)
+             return self.loss_function(params, output)

        return loss

+     def optimize_deterministic(self, n_max_iterations, output_every_nth=None):
+         loss = self.compute_loss(self.params)
+         print(f"loss in iteration 0: {loss}")
+         if len(self.cost_history) == 0:  # add only if the model has not yet been optimized
+             self.cost_history.append(loss)
+
+         for i in range(1, n_max_iterations + 1):
+             self.gradient = self.compute_gradient(self.params)
+             updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
+             self.params = optax.apply_updates(self.params, updates)
+
+             if output_every_nth is not None and i % output_every_nth == 0:
+                 loss = self.compute_loss(self.params)
+                 self.cost_history.append(loss)
+                 print(f"loss in iteration {i}: {loss}")
+
+         loss = self.compute_loss(self.params)
+         print(f"Final loss: {loss}")
+
+
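The diff thus generalizes the former OcWc class into a model-agnostic gradient optimizer: any entry of model.args_names can be tuned against any entry of model.output_vars under a user-supplied loss. A minimal usage sketch under that reading (the loss function, parameter choice, and iteration counts below are illustrative assumptions, not taken from the change itself):

    import jax.numpy as jnp
    from neurolib.models.jax.wc import WCModel

    model = WCModel()  # assumes a default-constructed JAX Wilson-Cowan model

    def example_loss(params, output):
        # toy loss: drive the selected output variables towards zero
        return jnp.mean(output**2)

    opt = Optimize(
        model,
        example_loss,
        param_names=["exc_ext"],         # must appear in model.args_names
        target_param_names=["exc"],      # must appear in model.output_vars
    )
    opt.optimize_deterministic(n_max_iterations=50, output_every_nth=10)
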
+ class OcWc(Optimize):
+     def __init__(
+         self,
+         model,
+         target=None,
+         optimizer=optax.adabelief(1e-3),
+         control_param_names=wc_default_control_params,
+         target_param_names=wc_default_target_params,
+     ):
+         super().__init__(
+             model,
+             self.compute_total_cost,
+             control_param_names,
+             target_param_names,
+             target=target,
+             init_params=None,
+             optimizer=optimizer,
+         )
+         self.control = self.params
+         self.weights = getdefaultweights()
+
    def compute_total_cost(self, control, output):
        """
        Compute the total cost as the sum of accuracy cost and control strength cost.

        Parameters:
-         control (jax.numpy.ndarray): Control input array of shape ((len(control_params)), N, T).
-         output (jax.numpy.ndarray): Simulation output of shape ((len(target_params)), N, T).
+         control (dict[str, jax.numpy.ndarray]): Dictionary of control inputs, where each entry has shape (N, T).
+         output (jax.numpy.ndarray): Simulation output of shape ((len(target_param_names)), N, T).

        Returns:
        float: The total cost.
        """
        accuracy_cost = self.accuracy_cost(output)
-         control_strength_cost = self.control_strength_cost(control)
+         control_arr = jnp.array(list(control.values()))
+         control_strength_cost = self.control_strength_cost(control_arr)
        return accuracy_cost + control_strength_cost

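Because control is now a dictionary keyed by parameter name, the control-strength term first stacks its values into a single array. A short sketch of that conversion with hypothetical sizes (two control inputs, N nodes, T time steps):

    N, T = 2, 100
    control = {"exc_ext": jnp.zeros((N, T)), "inh_ext": jnp.zeros((N, T))}
    control_arr = jnp.array(list(control.values()))  # shape (2, N, T), in insertion order
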
+     # TODO: move cost functions outside
    def accuracy_cost(self, output):
        accuracy_cost = 0.0
        if self.weights["w_p"] != 0.0: