@@ -31,10 +31,15 @@
 sigmas = jnp.array(np.array([X, Y, Z]))  # Vector of Pauli matrices
 sigmas_sigmas = jnp.array(
     np.array(
-        [np.kron(X, X), np.kron(Y, Y), np.kron(Z, Z)]  # Vector of tensor products of Pauli matrices
+        [
+            np.kron(X, X),
+            np.kron(Y, Y),
+            np.kron(Z, Z),
+        ]  # Vector of tensor products of Pauli matrices
     )
 )
 
+
 def singlet(wires):
     # Encode a 2-qubit rotation-invariant initial state, i.e., the singlet state.
     qml.Hadamard(wires=wires[0])
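(The hunk above cuts `singlet` off after its first gate. For reference, a minimal standalone sketch of one standard preparation of the singlet state (|01> - |10>)/sqrt(2); this gate sequence is illustrative and is not necessarily the remainder of the demo's own function body.)

```python
import pennylane as qml

def singlet_sketch(wires):
    # Prepare the rotation-invariant singlet state (|01> - |10>)/sqrt(2)
    qml.Hadamard(wires=wires[0])  # |00>  ->  (|00> + |10>)/sqrt(2)
    qml.PauliZ(wires=wires[0])    #       ->  (|00> - |10>)/sqrt(2)
    qml.PauliX(wires=wires[1])    #       ->  (|01> - |11>)/sqrt(2)
    qml.CNOT(wires=wires)         #       ->  (|01> - |10>)/sqrt(2)
```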
@@ -76,7 +81,7 @@ def noise_layer(epsilon, wires):
 rep = 2  # Number of repeated vertical encoding
 
 active_atoms = 2  # Number of active atoms
-# Here we only have two active atoms since we fixed the oxygen (which becomes non-active) at the origin
+# Here we only have two active atoms since we fixed the oxygen (which becomes non-active) at the origin
 num_qubits = active_atoms * rep
 
 
@@ -93,22 +98,25 @@ def noise_layer(epsilon, wires):
 # If there is no such dependence, `catalyst.for_loop` can still be used.
 # Here we showcase both usages.
 
+
 @qjit
 @qml.qnode(dev)
 def vqlm_qjit(data, params):
     weights = params["params"]["weights"]
     alphas = params["params"]["alphas"]
     epsilon = params["params"]["epsilon"]
+
     # Initial state
     @catalyst.for_loop(0, rep, 1)
     def singlet_loop(i):
-        singlet(wires=jnp.arange(active_atoms)+ active_atoms * i)
+        singlet(wires=jnp.arange(active_atoms) + active_atoms * i)
+
     singlet_loop()
     # Initial encoding
     for i in range(num_qubits):
         equivariant_encoding(
             alphas[i, 0], jnp.asarray(data)[i % active_atoms, ...], wires=[i]
-    )
+        )
     # Reuploading model
     for d in range(D):
         qml.Barrier()
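Regarding the `catalyst.for_loop` used above: under `@qjit` the decorated body takes the loop index first, followed by any loop-carried values, and is invoked with the initial carried values (here `singlet_loop` carries nothing). A minimal self-contained sketch with a hypothetical `running_sum`:

```python
from catalyst import qjit, for_loop

@qjit
def running_sum(n: int):
    @for_loop(0, n, 1)
    def body(i, acc):  # loop index first, then the carried value
        return acc + i

    return body(0)  # invoked with the initial carried value

# running_sum(5) evaluates 0 + 1 + 2 + 3 + 4 = 10
```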
@@ -129,13 +137,20 @@ def singlet_loop(i):
     )
     return qml.expval(Observable)
 
+
 # vectorizing for batched training with `catalyst.vmap`
-vec_vqlm = catalyst.vmap(vqlm_qjit, in_axes=(0, {'params': {'alphas': None, 'epsilon': None, 'weights': None}}), out_axes=0)
+vec_vqlm = catalyst.vmap(
+    vqlm_qjit,
+    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
+    out_axes=0,
+)
+
 
 # loss function for cost
 def mse_loss(predictions, targets):
     return jnp.mean(0.5 * (predictions - targets) ** 2)
 
+
 # Compile a training step
 # many calls so compile = faster!
 @qjit
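In the `catalyst.vmap` call above, `in_axes=(0, {...: None})` maps over the leading axis of `data` while the `None` entries broadcast the single shared parameter pytree to every batch element. A usage sketch with hypothetical shapes:

```python
# Hypothetical batch: 256 configurations of `active_atoms` atoms in 3D;
# `params` is the shared parameter pytree from the training code.
batched_data = jnp.zeros((256, active_atoms, 3))
predictions = vec_vqlm(batched_data, params)  # one expectation value per configuration
```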
@@ -149,7 +164,7 @@ def cost(weights, loss_data):
 
     net_params = get_params(opt_state)
     loss = cost(net_params, loss_data)
-    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
     return loss, opt_update(step_i, grads, opt_state)
 
 
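A note on `catalyst.grad(cost, method="fd", h=1e-13, argnums=0)`: `method="fd"` requests finite-difference rather than automatic differentiation, with `h` as the step size. A plain-Python illustration of a one-sided scheme (the diff does not show which stencil Catalyst applies internally):

```python
def fd_grad(f, x, h=1e-7):
    # one-sided finite difference: df/dx ~= (f(x + h) - f(x)) / h
    return (f(x + h) - f(x)) / h

# fd_grad(lambda y: y**2, 3.0) is approximately 6.0
```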
@@ -188,11 +203,15 @@ def inference(loss_data, opt_state):
 data[:, 1, :] = positions[:, 2, :] - positions[:, 0, :]
 positions = data.copy()
 
-forces = forces[:, 1:, :]  # Select only the forces on the hydrogen atoms since the oxygen is fixed
+forces = forces[
+    :, 1:, :
+]  # Select only the forces on the hydrogen atoms since the oxygen is fixed
 
 
 # Splitting in train-test set
-indices_train = np.random.choice(np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False)
+indices_train = np.random.choice(
+    np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False
+)
 indices_test = np.setdiff1d(np.arange(shape[0]), indices_train)
 
 E_train, E_test = (energy[indices_train, 0], energy[indices_test, 0])
@@ -215,7 +234,9 @@ def inference(loss_data, opt_state):
 np.random.seed(42)
 epsilon = jnp.array(np.random.normal(0, 0.001, size=(D, num_qubits)))
 epsilon = None  # We disable SB for this specific example
-epsilon = jax.lax.stop_gradient(epsilon)  # comment if we wish to train the SB weights as well.
+epsilon = jax.lax.stop_gradient(
+    epsilon
+)  # comment if we wish to train the SB weights as well.
 
 
 opt_init, opt_update, get_params = optimizers.adam(1e-2)
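On the `jax.lax.stop_gradient` line above: it is the identity on the forward pass but its gradient is defined to be zero, which is what freezes the symmetry-breaking (SB) weights during training. A small self-contained illustration:

```python
import jax

def f(w):
    c = jax.lax.stop_gradient(w)  # treated as a constant by autodiff
    return w**2 + c * w           # d/dw = 2w + c

print(jax.grad(f)(2.0))  # 2*2 + 2 = 6.0
```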
@@ -224,8 +245,8 @@ def inference(loss_data, opt_state):
 running_loss = []
 
 
-num_batches = 5000 # number of optimization steps
-batch_size = 256 # number of training data per batch
+num_batches = 5000  # number of optimization steps
+batch_size = 256  # number of training data per batch
 
 batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
 loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -235,7 +256,9 @@ def inference(loss_data, opt_state):
 # We call `train_step` and `inference` many times, so the speedup from qjit will be quite significant!
 for ibatch in range(num_batches):
     # select a batch of training points
-    batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
+    batch = np.random.choice(
+        np.arange(np.shape(data_train)[0]), batch_size, replace=False
+    )
 
     # preparing the data
     loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -253,7 +276,7 @@ def inference(loss_data, opt_state):
 
 ### plotting ###
 fontsize = 12
-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
 plt.plot(history_loss[:, 0], "r-", label="training error")
 plt.plot(history_loss[:, 1], "b-", label="testing error")
 
@@ -265,7 +288,7 @@ def inference(loss_data, opt_state):
 plt.show()
 
 
-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
 plt.title("Energy predictions", fontsize=fontsize)
 plt.plot(energy[indices_test], E_pred, "ro", label="Test predictions")
 plt.plot(energy[indices_test], energy[indices_test], "k.-", lw=1, label="Exact")