
Commit bbd9221

black format
1 parent 442e9d2 commit bbd9221
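For reference, formatting like this is typically reproduced with black's defaults (88-character lines). The exact invocation and black version are not recorded in this commit, so the command below is an assumption:

    # hypothetical invocation; black's default line length is 88 characters
    python -m black demonstrations/tutorial_eqnn_force_field_catalyst_compiled.py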

File tree

1 file changed: +37 additions, -14 deletions

demonstrations/tutorial_eqnn_force_field_catalyst_compiled.py

Lines changed: 37 additions & 14 deletions
@@ -31,10 +31,15 @@
 sigmas = jnp.array(np.array([X, Y, Z]))  # Vector of Pauli matrices
 sigmas_sigmas = jnp.array(
     np.array(
-        [np.kron(X, X), np.kron(Y, Y), np.kron(Z, Z)] # Vector of tensor products of Pauli matrices
+        [
+            np.kron(X, X),
+            np.kron(Y, Y),
+            np.kron(Z, Z),
+        ]  # Vector of tensor products of Pauli matrices
     )
 )

+
 def singlet(wires):
     # Encode a 2-qubit rotation-invariant initial state, i.e., the singlet state.
     qml.Hadamard(wires=wires[0])
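The body of `singlet` is truncated by the diff context window. For orientation, a minimal sketch of one standard preparation of the singlet state (|01⟩ - |10⟩)/√2, assuming PennyLane is installed; this is an illustration, not necessarily the tutorial's exact gate sequence:

    import numpy as np
    import pennylane as qml

    dev = qml.device("default.qubit", wires=2)

    @qml.qnode(dev)
    def singlet_state():
        qml.Hadamard(wires=0)   # (|00> + |10>)/sqrt(2)
        qml.PauliZ(wires=0)     # (|00> - |10>)/sqrt(2)
        qml.PauliX(wires=1)     # (|01> - |11>)/sqrt(2)
        qml.CNOT(wires=[0, 1])  # (|01> - |10>)/sqrt(2), the singlet
        return qml.state()

    print(np.round(singlet_state(), 3))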
@@ -76,7 +81,7 @@ def noise_layer(epsilon, wires):
 rep = 2  # Number of repeated vertical encoding

 active_atoms = 2  # Number of active atoms
-# Here we only have two active atoms since we fixed the oxygen (which becomes non-active) at the origin
+# Here we only have two active atoms since we fixed the oxygen (which becomes non-active) at the origin
 num_qubits = active_atoms * rep


@@ -93,22 +98,25 @@ def noise_layer(epsilon, wires):
 # If there is no such dependence, `catalyst.for_loop` can still be used.
 # Here we showcase both usages.

+
 @qjit
 @qml.qnode(dev)
 def vqlm_qjit(data, params):
     weights = params["params"]["weights"]
     alphas = params["params"]["alphas"]
     epsilon = params["params"]["epsilon"]
+
     # Initial state
     @catalyst.for_loop(0, rep, 1)
     def singlet_loop(i):
-        singlet(wires=jnp.arange(active_atoms)+active_atoms*i)
+        singlet(wires=jnp.arange(active_atoms) + active_atoms * i)
+
     singlet_loop()
     # Initial encoding
     for i in range(num_qubits):
         equivariant_encoding(
             alphas[i, 0], jnp.asarray(data)[i % active_atoms, ...], wires=[i]
-            )
+        )
     # Reuploading model
     for d in range(D):
         qml.Barrier()
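As the comments in this hunk note, `catalyst.for_loop` compiles the loop body into a native loop rather than unrolling it, so wire indices may depend on the loop variable. A minimal standalone sketch of the same pattern; the device name and wire count are illustrative:

    import pennylane as qml
    from catalyst import qjit, for_loop

    dev = qml.device("lightning.qubit", wires=4)

    @qjit
    @qml.qnode(dev)
    def circuit():
        # The decorated body becomes a compiled loop; `i` is the loop index.
        @for_loop(0, 4, 1)
        def body(i):
            qml.Hadamard(wires=i)

        body()  # calling the decorated body executes the full loop
        return qml.expval(qml.PauliZ(0))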
@@ -129,13 +137,20 @@ def singlet_loop(i):
     )
     return qml.expval(Observable)

+
 # vectorizing for batched training with `catalyst.vmap`
-vec_vqlm = catalyst.vmap(vqlm_qjit, in_axes=(0, {'params': {'alphas': None, 'epsilon': None, 'weights': None}} ), out_axes=0)
+vec_vqlm = catalyst.vmap(
+    vqlm_qjit,
+    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
+    out_axes=0,
+)
+

 # loss function for cost
 def mse_loss(predictions, targets):
     return jnp.mean(0.5 * (predictions - targets) ** 2)

+
 # Compile a training step
 # many calls so compile = faster!
 @qjit
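In the `catalyst.vmap` call above, `in_axes` mirrors the function's arguments: the leading `0` batches `data` along its first axis, while the `None` entries in the parameter pytree broadcast the same unbatched weights to every batch element. A hypothetical call, with shapes assumed for illustration:

    # `batch_data` is stacked along axis 0; `params` is the unbatched pytree
    predictions = vec_vqlm(batch_data, params)  # one expectation value per batch entry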
@@ -149,7 +164,7 @@ def cost(weights, loss_data):

     net_params = get_params(opt_state)
     loss = cost(net_params, loss_data)
-    grads = catalyst.grad(cost, method = "fd", h=1e-13, argnums=0)(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
     return loss, opt_update(step_i, grads, opt_state)


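Here `method="fd"` selects finite-difference gradients with step size `h`, which avoids differentiating through the compiled circuit itself. A toy sketch of the same call pattern, assuming Catalyst is installed; the function and step size are illustrative:

    from catalyst import qjit, grad

    def toy_cost(x):
        return x**2

    @qjit
    def d_toy(x):
        # forward finite difference: (f(x + h) - f(x)) / h
        return grad(toy_cost, method="fd", h=1e-4)(x)

    print(d_toy(3.0))  # approximately 6.0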
@@ -188,11 +203,15 @@ def inference(loss_data, opt_state):
 data[:, 1, :] = positions[:, 2, :] - positions[:, 0, :]
 positions = data.copy()

-forces = forces[:, 1:, :] # Select only the forces on the hydrogen atoms since the oxygen is fixed
+forces = forces[
+    :, 1:, :
+]  # Select only the forces on the hydrogen atoms since the oxygen is fixed


 # Splitting in train-test set
-indices_train = np.random.choice(np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False)
+indices_train = np.random.choice(
+    np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False
+)
 indices_test = np.setdiff1d(np.arange(shape[0]), indices_train)

 E_train, E_test = (energy[indices_train, 0], energy[indices_test, 0])
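`np.setdiff1d` returns the complement of the training indices, so the two index sets partition the data 80/20 with no overlap. A quick check with a hypothetical dataset of 10 samples:

    import numpy as np

    indices_train = np.random.choice(np.arange(10), size=8, replace=False)
    indices_test = np.setdiff1d(np.arange(10), indices_train)
    assert len(np.intersect1d(indices_train, indices_test)) == 0
    assert len(indices_train) + len(indices_test) == 10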
@@ -215,7 +234,9 @@ def inference(loss_data, opt_state):
 np.random.seed(42)
 epsilon = jnp.array(np.random.normal(0, 0.001, size=(D, num_qubits)))
 epsilon = None  # We disable SB for this specific example
-epsilon = jax.lax.stop_gradient(epsilon) # comment if we wish to train the SB weights as well.
+epsilon = jax.lax.stop_gradient(
+    epsilon
+)  # comment if we wish to train the SB weights as well.


 opt_init, opt_update, get_params = optimizers.adam(1e-2)
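`jax.lax.stop_gradient` keeps `epsilon` out of the gradient computation, so the symmetry-breaking (SB) weights stay frozen during training. A toy illustration of the effect, with values chosen only for demonstration:

    import jax
    import jax.numpy as jnp

    def f(w, e):
        return jnp.sum((w + jax.lax.stop_gradient(e)) ** 2)

    w, e = jnp.ones(3), jnp.ones(3)
    gw, ge = jax.grad(f, argnums=(0, 1))(w, e)
    print(gw)  # [4. 4. 4.]: w still receives gradients
    print(ge)  # [0. 0. 0.]: e is treated as a constant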
@@ -224,8 +245,8 @@ def inference(loss_data, opt_state):
 running_loss = []


-num_batches = 5000 # number of optimization steps
-batch_size = 256 # number of training data per batch
+num_batches = 5000  # number of optimization steps
+batch_size = 256  # number of training data per batch

 batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
 loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -235,7 +256,9 @@ def inference(loss_data, opt_state):
 # We call `train_step` and `inference` many times, so the speedup from qjit will be quite significant!
 for ibatch in range(num_batches):
     # select a batch of training points
-    batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
+    batch = np.random.choice(
+        np.arange(np.shape(data_train)[0]), batch_size, replace=False
+    )

     # preparing the data
     loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -253,7 +276,7 @@ def inference(loss_data, opt_state):

 ### plotting ###
 fontsize = 12
-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
 plt.plot(history_loss[:, 0], "r-", label="training error")
 plt.plot(history_loss[:, 1], "b-", label="testing error")

@@ -265,7 +288,7 @@ def inference(loss_data, opt_state):
 plt.show()


-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
 plt.title("Energy predictions", fontsize=fontsize)
 plt.plot(energy[indices_test], E_pred, "ro", label="Test predictions")
 plt.plot(energy[indices_test], energy[indices_test], "k.-", lw=1, label="Exact")
