@@ -222,32 +222,3 @@ def compute_sync_fourier_cost(self, output):
222
222
def compute_kuramoto_cost (self , output ):
223
223
phase = jnp .angle (hilbert_jax (output , axis = 2 ))
224
224
return - jnp .mean (jnp .abs (jnp .mean (jnp .exp (complex (0 ,1 )* phase ), axis = 1 )))
225
-
226
-
227
- def optimize_deterministic (self , n_max_iterations , output_every_nth = None ):
228
- """Compute the optimal control signal for noise averaging method 0.
229
-
230
- :param n_max_iterations: maximum number of iterations of gradient descent
231
- :type n_max_iterations: int
232
- """
233
-
234
- output = self .get_output (self .control )
235
-
236
- cost = self .compute_total_cost (self .control , output )
237
- print (f"Cost in iteration 0: %s" % (cost ))
238
- if len (self .cost_history ) == 0 : # add only if control model has not yet been optimized
239
- self .cost_history .append (cost )
240
-
241
- for i in range (1 , n_max_iterations + 1 ):
242
- self .gradient = self .compute_gradient (self .control )
243
-
244
- updates , self .opt_state = self .optimizer .update (self .gradient , self .opt_state )
245
- self .control = optax .apply_updates (self .control , updates )
246
-
247
- output = self .get_output (self .control )
248
- if output_every_nth is not None and i % output_every_nth == 0 :
249
- cost = self .compute_total_cost (self .control , output )
250
- self .cost_history .append (cost )
251
- print (f"Cost in iteration %s: %s" % (i , cost ))
252
-
253
- print (f"Final cost : %s" % (cost ))
0 commit comments