Commit 21f509b

Add a how-to for catalyst-compiling "Symmetry-invariant quantum machine learning force fields"
The how-to contains the full code listing and some brief tutorial commentary.
1 parent c74fab7

1 file changed: 280 additions, 0 deletions
r"""Compiling "Symmetry-invariant quantum machine learning force fields" with PennyLane-Catalyst
=================================================================================================

To speed up the training process, we can use PennyLane-Catalyst to compile the training workflow.

Unlike ``jax.jit``, ``catalyst.qjit`` natively understands PennyLane quantum instructions, and it
performs better on the Lightning backend.
"""

import pennylane as qml
import numpy as np

import jax
from jax import numpy as jnp

import scipy
import matplotlib.pyplot as plt
import sklearn

import catalyst
from catalyst import qjit

from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1.0j], [1.0j, 0]])
Z = np.array([[1, 0], [0, -1]])

sigmas = jnp.array(np.array([X, Y, Z]))  # Vector of Pauli matrices
sigmas_sigmas = jnp.array(
    np.array(
        [np.kron(X, X), np.kron(Y, Y), np.kron(Z, Z)]  # Vector of tensor products of Pauli matrices
    )
)


def singlet(wires):
    # Encode a 2-qubit rotation-invariant initial state, i.e., the singlet state.
    qml.Hadamard(wires=wires[0])
    qml.PauliZ(wires=wires[0])
    qml.PauliX(wires=wires[1])
    qml.CNOT(wires=wires)

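######################################################################
# Quick sanity check (illustrative snippet; ``check_dev`` and ``check_singlet``
# are made up for this check and are not part of the model): the gates above map
# |00> to the singlet state (|01> - |10>)/sqrt(2).

check_dev = qml.device("default.qubit", wires=2)


@qml.qnode(check_dev)
def check_singlet():
    singlet(wires=[0, 1])
    return qml.state()


print(np.round(check_singlet(), 3))  # expected: [0, 0.707, -0.707, 0]
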

def equivariant_encoding(alpha, data, wires):
    # data (jax array): Cartesian coordinates of atom i
    # alpha (jax array): trainable scaling parameter
    # Data-dependent Hamiltonian: contract the coordinates with the Pauli vector, H = data . (X, Y, Z)
    hamiltonian = jnp.einsum("i,ijk", data, sigmas)
    U = jax.scipy.linalg.expm(-1.0j * alpha * hamiltonian / 2)
    qml.QubitUnitary(U, wires=wires, id="E")


def trainable_layer(weight, wires):
    # Heisenberg interaction: H = XX + YY + ZZ, summing the tensor-product Paulis
    hamiltonian = jnp.einsum("ijk->jk", sigmas_sigmas)
    U = jax.scipy.linalg.expm(-1.0j * weight * hamiltonian)
    qml.QubitUnitary(U, wires=wires, id="U")


# Invariant observable
Heisenberg = [
    qml.PauliX(0) @ qml.PauliX(1),
    qml.PauliY(0) @ qml.PauliY(1),
    qml.PauliZ(0) @ qml.PauliZ(1),
]
Observable = qml.Hamiltonian(np.ones((3)), Heisenberg)


def noise_layer(epsilon, wires):
    # Small symmetry-breaking RZ rotations, one per wire
    for i, w in enumerate(wires):
        qml.RZ(epsilon[i], wires=[w])


D = 6  # Depth of the model
B = 1  # Number of repetitions inside a trainable layer
rep = 2  # Number of repeated vertical encodings

active_atoms = 2  # Number of active atoms
# Here we have only two active atoms, since we fix the oxygen (which becomes non-active) at the origin
num_qubits = active_atoms * rep


# We need to use the "lightning.qubit" device for Catalyst compilation.
dev = qml.device("lightning.qubit", wires=num_qubits)


######################################################################
# The core function, which is called many times during training, benefits from
# being just-in-time compiled with qjit. All we need to do is decorate the
# function with the ``@qjit`` decorator.
#
# Catalyst provides its own ``for_loop`` function to work with qjit.
# ``catalyst.for_loop`` must be used when the loop bounds or step depend on the
# qjit-ted function's input arguments; when there is no such dependence, it can
# still be used. Below is a standalone toy illustration of the pattern, followed
# by the full model, which showcases both usages.
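
# A minimal sketch (``partial_sum`` is a toy function made up for illustration,
# not part of the model): the loop bound ``n`` is an argument of the qjit-ted
# function, so a Python ``for`` cannot be unrolled at trace time and
# ``catalyst.for_loop`` is required. The loop body receives the iteration index
# plus any loop-carried values and returns the updated carry; calling it with
# the initial carry runs the loop.
@qjit
def partial_sum(n, xs):
    @catalyst.for_loop(0, n, 1)
    def body(i, acc):
        return acc + xs[i]  # accumulate xs[0] + ... + xs[n - 1]

    return body(0.0)


# partial_sum(3, jnp.arange(5.0)) evaluates to 3.0 (= 0.0 + 1.0 + 2.0)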

@qjit
@qml.qnode(dev)
def vqlm_qjit(data, params):
    weights = params["params"]["weights"]
    alphas = params["params"]["alphas"]
    epsilon = params["params"]["epsilon"]

    # Initial state
    @catalyst.for_loop(0, rep, 1)
    def singlet_loop(i):
        singlet(wires=jnp.arange(active_atoms) + active_atoms * i)

    singlet_loop()

    # Initial encoding
    for i in range(num_qubits):
        equivariant_encoding(alphas[i, 0], jnp.asarray(data)[i % active_atoms, ...], wires=[i])

    # Reuploading model
    for d in range(D):
        qml.Barrier()
        for b in range(B):
            # Even layer
            for i in range(0, num_qubits - 1, 2):
                trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
            # Odd layer
            for i in range(1, num_qubits, 2):
                trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
        # Symmetry-breaking
        if epsilon is not None:
            noise_layer(epsilon[d, :], range(num_qubits))
        # Encoding
        for i in range(num_qubits):
            equivariant_encoding(alphas[i, d + 1], jnp.asarray(data)[i % active_atoms, ...], wires=[i])

    return qml.expval(Observable)

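######################################################################
# Note that the Python ``for`` loops above have static bounds (``num_qubits``,
# ``D``, ``B``), so qjit simply unrolls them at trace time; only the singlet
# preparation uses ``catalyst.for_loop``.
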

# Vectorize over the data for batched training with `catalyst.vmap`
vec_vqlm = catalyst.vmap(
    vqlm_qjit,
    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
    out_axes=0,
)

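######################################################################
# ``in_axes`` mirrors ``jax.vmap``: the first argument (the data) is batched
# along its leading axis, while every leaf of the parameter pytree is broadcast
# (``None``). Illustrative shapes, assuming a batch of ``batch_size`` samples as
# prepared below: ``data`` of shape ``(batch_size, active_atoms, 3)`` maps to a
# vector of ``batch_size`` energies.
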

# Mean-squared-error loss for the cost function
def mse_loss(predictions, targets):
    return jnp.mean(0.5 * (predictions - targets) ** 2)


# Compile a training step: it is called many times, so compilation pays off!
@qjit
def train_step(step_i, opt_state, loss_data):
    def cost(weights, loss_data):
        data, E_target, F_target = loss_data
        E_pred = vec_vqlm(data, weights)
        l = mse_loss(E_pred, E_target)
        return l

    net_params = get_params(opt_state)
    loss = cost(net_params, loss_data)
    # Differentiate the cost by finite differences (method="fd")
    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
    return loss, opt_update(step_i, grads, opt_state)
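
######################################################################
# A single compiled step can be smoke-tested once the optimizer state and a
# batch exist (illustrative call, matching the loop below):
#
#     loss, opt_state = train_step(0, opt_state, loss_data)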


# Return the prediction and loss at inference time, e.g., for testing
@qjit
def inference(loss_data, opt_state):
    data, E_target, F_target = loss_data
    net_params = get_params(opt_state)
    E_pred = vec_vqlm(data, net_params)
    l = mse_loss(E_pred, E_target)
    return E_pred, l


#################### main ##########################
### setup ###
# Load the data
energy = np.load("eqnn_force_field_data/Energy.npy")
forces = np.load("eqnn_force_field_data/Forces.npy")
positions = np.load(
    "eqnn_force_field_data/Positions.npy"
)  # Cartesian coordinates, shape = (nbr_sample, nbr_atoms, 3)
shape = np.shape(positions)


### Scaling the energy to fit in [-1, 1]
scaler = MinMaxScaler((-1, 1))

energy = scaler.fit_transform(energy)
# Forces are gradients of the energy, so they are rescaled by the same factor
forces = forces * scaler.scale_

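######################################################################
# Illustrative note: predictions made in this scaled space can be mapped back to
# the original units with the fitted scaler, e.g.
# ``scaler.inverse_transform(np.asarray(E_pred).reshape(-1, 1))``.
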

# Placing the oxygen at the origin: keep the two O-H displacement vectors
data = np.zeros((shape[0], 2, 3))
data[:, 0, :] = positions[:, 1, :] - positions[:, 0, :]
data[:, 1, :] = positions[:, 2, :] - positions[:, 0, :]
positions = data.copy()

forces = forces[:, 1:, :]  # Select only the forces on the hydrogen atoms, since the oxygen is fixed


# Splitting into train and test sets
indices_train = np.random.choice(np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False)
indices_test = np.setdiff1d(np.arange(shape[0]), indices_train)

E_train, E_test = (energy[indices_train, 0], energy[indices_test, 0])
F_train, F_test = forces[indices_train, ...], forces[indices_test, ...]
data_train, data_test = (
    jnp.array(positions[indices_train, ...]),
    jnp.array(positions[indices_test, ...]),
)

### training ###
opt_init, opt_update, get_params = optimizers.adam(1e-2)

np.random.seed(42)
# The circuit indexes weights[i, d + 1, b] for d = 0, ..., D - 1, so we allocate
# D + 1 slots along the depth axis to keep the indexing in bounds.
weights = np.zeros((num_qubits, D + 1, B))
weights[0] = np.random.uniform(0, np.pi, 1)
weights = jnp.array(weights)

# Encoding weights
alphas = jnp.array(np.ones((num_qubits, D + 1)))

# Symmetry-breaking (SB)
np.random.seed(42)
epsilon = jnp.array(np.random.normal(0, 0.001, size=(D, num_qubits)))
epsilon = None  # We disable SB for this specific example
epsilon = jax.lax.stop_gradient(epsilon)  # Comment out this line to also train the SB weights


net_params = {"params": {"weights": weights, "alphas": alphas, "epsilon": epsilon}}
opt_state = opt_init(net_params)
running_loss = []


num_batches = 5000  # Number of optimization steps
batch_size = 256  # Number of training samples per batch

batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]


# The main training loop
# We call `train_step` and `inference` many times, so the speedup from qjit is quite significant!
for ibatch in range(num_batches):
    # Select a batch of training points
    batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)

    # Prepare the data
    loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
    loss_data_test = data_test, E_test, F_test

    # Perform one training step, passing the current step index for the Adam update
    loss, opt_state = train_step(ibatch, opt_state, loss_data)

    # Compute the test loss and energy predictions
    E_pred, test_loss = inference(loss_data_test, opt_state)
    running_loss.append([float(loss), float(test_loss)])


history_loss = np.array(running_loss)

### plotting ###
fontsize = 12
plt.figure(figsize=(4, 4))
plt.plot(history_loss[:, 0], "r-", label="training error")
plt.plot(history_loss[:, 1], "b-", label="testing error")

plt.yscale("log")
plt.xlabel("Optimization Steps", fontsize=fontsize)
plt.ylabel("Mean Squared Error", fontsize=fontsize)
plt.legend(fontsize=fontsize)
plt.tight_layout()
plt.show()


plt.figure(figsize=(4, 4))
plt.title("Energy predictions", fontsize=fontsize)
plt.plot(energy[indices_test], E_pred, "ro", label="Test predictions")
plt.plot(energy[indices_test], energy[indices_test], "k.-", lw=1, label="Exact")
plt.xlabel("Exact energy", fontsize=fontsize)
plt.ylabel("Predicted energy", fontsize=fontsize)
plt.legend(fontsize=fontsize)
plt.tight_layout()
plt.show()
