@@ -42,14 +42,17 @@ def __init__(
         self.args = dict(zip(self.model.args_names, args_values))

         self.T = len(self.args["t"])
-        self.startind = self.model.getMaxDelay()
+        self.startind = self.model.getMaxDelay() + 1

         if init_params is not None:
             self.params = init_params
         else:
             self.params = dict(zip(param_names, [self.args[p] for p in param_names]))
         self.opt_state = self.optimizer.init(self.params)

-        compute_loss = lambda params: self.loss_function(self.get_output(params)) + self.regularization_function(params)
+        # TODO: instead apply individually to each param
+        compute_loss = lambda params: self.loss_function(
+            jnp.stack(list(self.get_output(params).values()))
+        ) + self.regularization_function(params)

         self.compute_loss = jax.jit(compute_loss)
         self.compute_gradient = jax.jit(jax.grad(self.compute_loss))
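Note on the loss refactor above: `get_output` now returns a dict (see the next hunk), so its values are stacked into a single array before being passed to `loss_function`; the TODO records the intent to later apply the loss per entry instead. A minimal sketch of this jit-plus-grad pattern, with stand-in functions, names, and shapes (none of the bodies below are the class's actual implementations):

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins for self.get_output / self.loss_function /
# self.regularization_function; names, shapes, and values are assumptions.
target = jnp.zeros((2, 100))

def get_output(params):
    # Dict of per-variable time series, mirroring the refactored get_output.
    return {
        "x": params["a"] * jnp.ones(100),
        "y": params["b"] * jnp.ones(100),
    }

def loss_function(output):
    return jnp.mean((output - target) ** 2)

def regularization_function(params):
    return 1e-4 * sum(jnp.sum(p**2) for p in params.values())

compute_loss = jax.jit(
    lambda params: loss_function(jnp.stack(list(get_output(params).values())))
    + regularization_function(params)
)
compute_gradient = jax.jit(jax.grad(compute_loss))

params = {"a": jnp.array(0.5), "b": jnp.array(-0.2)}
print(compute_loss(params), compute_gradient(params))
```

Because `params` is a dict, `jax.grad` returns a gradient pytree with the same structure, which is exactly what optax consumes in the update loop of the next hunk.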
@@ -70,25 +73,25 @@ def simulate(self, params):

     def get_output(self, params):
         simulation_results = self.simulate(params)
-        return jnp.stack([simulation_results[tp][:, self.startind:] for tp in self.target_param_names])
+        return {tp: simulation_results[tp][:, self.startind:] for tp in self.target_param_names}

     def optimize(self, n_max_iterations, output_every_nth=None):
-        loss = self.compute_loss(self.control)
+        loss = self.compute_loss(self.params)
         print(f"loss in iteration 0: %s" % (loss))
-        if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
+        if len(self.cost_history) == 0:  # add only if params have not yet been optimized
             self.cost_history.append(loss)

         for i in range(1, n_max_iterations + 1):
-            self.gradient = self.compute_gradient(self.control)
+            self.gradient = self.compute_gradient(self.params)
             updates, self.opt_state = self.optimizer.update(self.gradient, self.opt_state)
-            self.control = optax.apply_updates(self.control, updates)
+            self.params = optax.apply_updates(self.params, updates)

             if output_every_nth is not None and i % output_every_nth == 0:
-                loss = self.compute_loss(self.control)
+                loss = self.compute_loss(self.params)
                 self.cost_history.append(loss)
                 print(f"loss in iteration %s: %s" % (i, loss))

-        loss = self.compute_loss(self.control)
+        loss = self.compute_loss(self.params)
         print(f"Final loss: %s" % (loss))
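The rewritten loop is the standard optax recipe, now driven by `self.params` instead of `self.control`: initialize the optimizer state once, then alternate `update` and `apply_updates` over the params pytree. A runnable toy version under assumed names (an adam optimizer and a scalar quadratic loss stand in for the model's actual configuration):

```python
import jax
import jax.numpy as jnp
import optax

# Stand-in loss: quadratic with minimum at params["a"] == 3.0.
compute_loss = jax.jit(lambda params: jnp.sum((params["a"] - 3.0) ** 2))
compute_gradient = jax.jit(jax.grad(compute_loss))

params = {"a": jnp.array(0.0)}
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)  # init once, as in __init__ above

for i in range(1, 101):
    gradient = compute_gradient(params)
    updates, opt_state = optimizer.update(gradient, opt_state)
    params = optax.apply_updates(params, updates)

print(compute_loss(params))  # should be close to 0 after 100 steps
```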
@@ -176,3 +179,7 @@ def control_strength_cost(self, control):
     def compute_ds_cost(self, control):
         eps = 1e-6  # avoid grad(sqrt(0.0))
         return jnp.sum(jnp.sqrt(jnp.sum(control**2, axis=2) * self.model.params.dt + eps))
+
+    def optimize(self, *args, **kwargs):
+        super().optimize(*args, **kwargs)
+        self.control = self.params
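Two notes on this hunk: the `eps` shift is needed because the derivative of `sqrt` diverges at zero, so without it the gradient of the ds cost would be non-finite wherever the control is exactly zero, and the new `optimize` override keeps the legacy `self.control` attribute pointing at the optimized `self.params`. A quick demonstration of the `eps` effect:

```python
import jax
import jax.numpy as jnp

# d/dx sqrt(x) = 1 / (2 * sqrt(x)) blows up at x = 0:
print(jax.grad(jnp.sqrt)(0.0))  # inf

# Shifting by eps keeps the gradient large but finite: 1 / (2 * sqrt(eps)).
eps = 1e-6
print(jax.grad(lambda x: jnp.sqrt(x + eps))(0.0))  # 500.0
```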