Skip to content

Commit 442e9d2

Browse files
committed
move opt_init, ... = optimizers.adam(...) into the original location
1 parent f83f690 commit 442e9d2

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

demonstrations/tutorial_eqnn_force_field_catalyst_compiled.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,6 @@ def inference(loss_data, opt_state):
203203
)
204204

205205
### training ###
206-
opt_init, opt_update, get_params = optimizers.adam(1e-2)
207-
208206
np.random.seed(42)
209207
weights = np.zeros((num_qubits, D, B))
210208
weights[0] = np.random.uniform(0, np.pi, 1)
@@ -220,7 +218,7 @@ def inference(loss_data, opt_state):
220218
epsilon = jax.lax.stop_gradient(epsilon) # comment if we wish to train the SB weights as well.
221219

222220

223-
221+
opt_init, opt_update, get_params = optimizers.adam(1e-2)
224222
net_params = {"params": {"weights": weights, "alphas": alphas, "epsilon": epsilon}}
225223
opt_state = opt_init(net_params)
226224
running_loss = []

0 commit comments

Comments
 (0)