r"""Compiling "Symmetry-invariant quantum machine learning force fields" with PennyLane-Catalyst
=================================================================================================

To speed up training, we can use PennyLane-Catalyst to compile the full training workflow.

Unlike ``jax.jit``, ``catalyst.qjit`` natively understands PennyLane quantum instructions and
performs better on Lightning backends.
"""

import pennylane as qml
import numpy as np

import jax
from jax import numpy as jnp

import scipy
import matplotlib.pyplot as plt
import sklearn

import catalyst
from catalyst import qjit

from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
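
######################################################################
# As a quick, self-contained sketch of the point made in the header (this toy device and
# circuit are our own illustration, not part of the original workflow), `qjit` wraps a QNode
# running on a Lightning device much like `jax.jit` wraps a pure JAX function:

_toy_dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(_toy_dev)
def _toy_circuit(theta):
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

# Calling `_toy_circuit(0.5)` compiles it once; later calls reuse the compiled program.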

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1.0j], [1.0j, 0]])
Z = np.array([[1, 0], [0, -1]])

sigmas = jnp.array(np.array([X, Y, Z]))  # Vector of Pauli matrices
sigmas_sigmas = jnp.array(
    np.array(
        [np.kron(X, X), np.kron(Y, Y), np.kron(Z, Z)]  # Vector of tensor products of Pauli matrices
    )
)

def singlet(wires):
    # Encode a two-qubit rotation-invariant initial state, i.e., the singlet state.
    qml.Hadamard(wires=wires[0])
    qml.PauliZ(wires=wires[0])
    qml.PauliX(wires=wires[1])
    qml.CNOT(wires=wires)


def equivariant_encoding(alpha, data, wires):
    # data (jax array): Cartesian coordinates of atom i
    # alpha (jax array): trainable scaling parameter
    hamiltonian = jnp.einsum("i,ijk", data, sigmas)  # rotation generator data . sigma
    U = jax.scipy.linalg.expm(-1.0j * alpha * hamiltonian / 2)
    qml.QubitUnitary(U, wires=wires, id="E")


def trainable_layer(weight, wires):
    hamiltonian = jnp.einsum("ijk->jk", sigmas_sigmas)  # Heisenberg interaction XX + YY + ZZ
    U = jax.scipy.linalg.expm(-1.0j * weight * hamiltonian)
    qml.QubitUnitary(U, wires=wires, id="U")


# Invariant observable
Heisenberg = [
    qml.PauliX(0) @ qml.PauliX(1),
    qml.PauliY(0) @ qml.PauliY(1),
    qml.PauliZ(0) @ qml.PauliZ(1),
]
Observable = qml.Hamiltonian(np.ones((3)), Heisenberg)


def noise_layer(epsilon, wires):
    for i, w in enumerate(wires):
        qml.RZ(epsilon[i], wires=[w])


D = 6  # Depth of the model
B = 1  # Number of repetitions inside a trainable layer
rep = 2  # Number of repeated vertical encodings

active_atoms = 2  # Number of active atoms
# Here we have only two active atoms, since the oxygen (which becomes non-active) is fixed at the origin.
num_qubits = active_atoms * rep


# We need to use the "lightning.qubit" device for Catalyst compilation.
dev = qml.device("lightning.qubit", wires=num_qubits)


######################################################################
# The core function, which is called many times during training, can benefit from being
# just-in-time compiled with qjit. All we need to do is decorate it with the `@qjit` decorator.
#
# Catalyst provides its own `for_loop` function to work with qjit.
# `catalyst.for_loop` must be used when the loop bounds or step depend on the qjit-ted
# function's input arguments; when there is no such dependence, it can still be used.
# Here we showcase both usages.

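######################################################################
# The model below only uses `for_loop` with bounds that are fixed Python constants. As a
# minimal illustrative sketch (our own toy example, not part of the original model), this is
# how a loop whose bound is a runtime argument of the compiled function could look; a plain
# Python `for` loop cannot express this under tracing.

@qjit
@qml.qnode(dev)
def repeat_rx(n, theta):
    # Apply RX(theta) on wire 0, n times, where n is only known at call time.
    @catalyst.for_loop(0, n, 1)
    def loop_body(i):
        qml.RX(theta, wires=0)
    loop_body()
    return qml.expval(qml.PauliZ(0))

# Example usage: repeat_rx(3, 0.1) compiles on the first call and runs the loop three times.
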
@qjit
@qml.qnode(dev)
def vqlm_qjit(data, params):
    weights = params["params"]["weights"]
    alphas = params["params"]["alphas"]
    epsilon = params["params"]["epsilon"]
    # Initial state
    @catalyst.for_loop(0, rep, 1)
    def singlet_loop(i):
        singlet(wires=jnp.arange(active_atoms) + active_atoms * i)
    singlet_loop()
    # Initial encoding
    for i in range(num_qubits):
        equivariant_encoding(
            alphas[i, 0], jnp.asarray(data)[i % active_atoms, ...], wires=[i]
        )
    # Reuploading model
    for d in range(D):
        qml.Barrier()
        for b in range(B):
            # Even layer
            for i in range(0, num_qubits - 1, 2):
                trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
            # Odd layer
            for i in range(1, num_qubits, 2):
                trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
            # Symmetry-breaking
            if epsilon is not None:
                noise_layer(epsilon[d, :], range(num_qubits))
            # Encoding
            for i in range(num_qubits):
                equivariant_encoding(
                    alphas[i, d + 1], jnp.asarray(data)[i % active_atoms, ...], wires=[i]
                )
    return qml.expval(Observable)

# Vectorizing for batched training with `catalyst.vmap`
vec_vqlm = catalyst.vmap(
    vqlm_qjit,
    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
    out_axes=0,
)
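# Only the data argument is batched (in_axes of 0 over its leading axis); the `None` entries
# broadcast the same parameter dictionary to every sample, and `out_axes=0` stacks the scalar
# energies into a batch of predictions.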

# Mean squared error loss function
def mse_loss(predictions, targets):
    return jnp.mean(0.5 * (predictions - targets) ** 2)

# Compile a full training step.
# It is called many times over, so compilation pays off quickly.
@qjit
def train_step(step_i, opt_state, loss_data):

    def cost(weights, loss_data):
        data, E_target, F_target = loss_data
        E_pred = vec_vqlm(data, weights)
        l = mse_loss(E_pred, E_target)
        return l

    net_params = get_params(opt_state)
    loss = cost(net_params, loss_data)
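    # Differentiate the compiled cost with Catalyst's finite-difference rule ("fd") and step size h.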
    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
    return loss, opt_update(step_i, grads, opt_state)


# Return the predictions and the loss at inference time, e.g. for testing
@qjit
def inference(loss_data, opt_state):
    data, E_target, F_target = loss_data
    net_params = get_params(opt_state)
    E_pred = vec_vqlm(data, net_params)
    l = mse_loss(E_pred, E_target)
    return E_pred, l


#################### main ##########################
### setup ###
# Load the data
energy = np.load("eqnn_force_field_data/Energy.npy")
forces = np.load("eqnn_force_field_data/Forces.npy")
positions = np.load(
    "eqnn_force_field_data/Positions.npy"
)  # Cartesian coordinates, shape = (nbr_sample, nbr_atoms, 3)
shape = np.shape(positions)


### Scaling the energy to fit in [-1, 1]

scaler = MinMaxScaler((-1, 1))

energy = scaler.fit_transform(energy)
forces = forces * scaler.scale_


# Placing the oxygen at the origin
data = np.zeros((shape[0], 2, 3))
data[:, 0, :] = positions[:, 1, :] - positions[:, 0, :]
data[:, 1, :] = positions[:, 2, :] - positions[:, 0, :]
positions = data.copy()

forces = forces[:, 1:, :]  # Select only the forces on the hydrogen atoms, since the oxygen is fixed


# Splitting into training and test sets
indices_train = np.random.choice(np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False)
indices_test = np.setdiff1d(np.arange(shape[0]), indices_train)

E_train, E_test = (energy[indices_train, 0], energy[indices_test, 0])
F_train, F_test = forces[indices_train, ...], forces[indices_test, ...]
data_train, data_test = (
    jnp.array(positions[indices_train, ...]),
    jnp.array(positions[indices_test, ...]),
)

### training ###
opt_init, opt_update, get_params = optimizers.adam(1e-2)

np.random.seed(42)
weights = np.zeros((num_qubits, D, B))
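# Initialize only the first qubit's weights, with a single shared random value; all other weights start at zero.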
weights[0] = np.random.uniform(0, np.pi, 1)
weights = jnp.array(weights)

# Encoding weights
alphas = jnp.array(np.ones((num_qubits, D + 1)))

# Symmetry-breaking (SB)
np.random.seed(42)
epsilon = jnp.array(np.random.normal(0, 0.001, size=(D, num_qubits)))
epsilon = None  # We disable SB for this specific example
epsilon = jax.lax.stop_gradient(epsilon)  # comment this line out to train the SB weights as well
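# With epsilon set to None above, stop_gradient acts on an empty pytree and is effectively a no-op;
# it only matters when the symmetry-breaking parameters are kept.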


net_params = {"params": {"weights": weights, "alphas": alphas, "epsilon": epsilon}}
opt_state = opt_init(net_params)
running_loss = []


num_batches = 5000  # number of optimization steps
batch_size = 256  # number of training data points per batch

batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]


# The main training loop
# We call `train_step` and `inference` many times, so the speedup from qjit is quite significant!
for ibatch in range(num_batches):
    # select a batch of training points
    batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)

    # preparing the data
    loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
    loss_data_test = data_test, E_test, F_test

    # perform one training step
    loss, opt_state = train_step(ibatch, opt_state, loss_data)

    # computing the test loss and energy predictions
    E_pred, test_loss = inference(loss_data_test, opt_state)
    running_loss.append([float(loss), float(test_loss)])


history_loss = np.array(running_loss)

### plotting ###
fontsize = 12
plt.figure(figsize=(4, 4))
plt.plot(history_loss[:, 0], "r-", label="training error")
plt.plot(history_loss[:, 1], "b-", label="testing error")

plt.yscale("log")
plt.xlabel("Optimization Steps", fontsize=fontsize)
plt.ylabel("Mean Squared Error", fontsize=fontsize)
plt.legend(fontsize=fontsize)
plt.tight_layout()
plt.show()


plt.figure(figsize=(4, 4))
plt.title("Energy predictions", fontsize=fontsize)
plt.plot(energy[indices_test], E_pred, "ro", label="Test predictions")
plt.plot(energy[indices_test], energy[indices_test], "k.-", lw=1, label="Exact")
plt.xlabel("Exact energy", fontsize=fontsize)
plt.ylabel("Predicted energy", fontsize=fontsize)
plt.legend(fontsize=fontsize)
plt.tight_layout()
plt.show()