Add a how-to for catalyst-compiling "Symmetry-invariant quantum machine learning force fields" #1222
Open · paul0403 wants to merge 15 commits into master from eqnn_forcefield_qjit
Commits (15)

- `21f509b` Add a how-to for catalyst-compiling "Symmetry-invariant quantum machi… (paul0403)
- `f83f690` grammar; remove unused imported (paul0403)
- `442e9d2` move `opt_init, ... = optimizers.adam(...)` into the original location (paul0403)
- `bbd9221` black format (paul0403)
- `bdc6f4c` Merge remote-tracking branch 'origin/master' into eqnn_forcefield_qjit (paul0403)
- `710b99a` Merge remote-tracking branch 'origin/master' into eqnn_forcefield_qjit (paul0403)
- `c66e485` add qjit to original demo (paul0403)
- `c72a374` remove second qjit demo (paul0403)
- `58898d4` Merge remote-tracking branch 'origin/master' into eqnn_forcefield_qjit (paul0403)
- `2d9ef0c` apply PR suggestions (paul0403)
- `b4d58e7` Merge remote-tracking branch 'origin/master' into eqnn_forcefield_qjit (paul0403)
- `d7845e4` update date of last modification (paul0403)
- `f6adfd5` Merge remote-tracking branch 'origin/master' into eqnn_forcefield_qjit (paul0403)
- `23e072d` trigger CI (isaacdevlugt)
- `1823aa8` try CI with smaller run (paul0403)
Changes from 8 commits

The diff touches the demo's Python source in four hunks.

```diff
@@ -143,6 +143,10 @@
 import matplotlib.pyplot as plt
 import sklearn
 
+######################################################################
+# To speed up the computation, we also import catalyst, a jit compiler for PennyLane quantum programs.
+import catalyst
+
 ######################################################################
 # Let us construct Pauli matrices, which are used to build the Hamiltonian.
 X = np.array([[0, 1], [1, 0]])
```
```diff
@@ -301,10 +305,13 @@ def noise_layer(epsilon, wires):
 #################################
 
 
-dev = qml.device("default.qubit", wires=num_qubits)
+######################################################################
+# To speed up the computation, we will be using catalyst to compile our quantum program, and we will be
+# running our program on the lightning backend instead of the default qubit backend.
+dev = qml.device("lightning.qubit", wires=num_qubits)
 
 
-@qml.qnode(dev, interface="jax")
+@qml.qnode(dev)
 def vqlm(data, params):
 
     weights = params["params"]["weights"]
```
@@ -396,25 +403,27 @@ def vqlm(data, params): | |
) | ||
|
||
################################# | ||
# We will know define the cost function and how to train the model using Jax. We will use the mean-square-error loss function. | ||
# To speed up the computation, we use the decorator ``@jax.jit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the | ||
# benefit that all following executions will be significantly faster, see the `Jax docs on jitting <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_. | ||
# We will now define the cost function and how to train the model using Jax. We will use the mean-square-error loss function. | ||
# To speed up the computation, we use the decorator ``@catalyst.qjit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the | ||
paul0403 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# benefit that all following executions will be significantly faster, see the `Catalyst documentation <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`_. | ||
paul0403 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
################################# | ||
from jax.example_libraries import optimizers | ||
|
||
# We vectorize the model over the data points | ||
vec_vqlm = jax.vmap(vqlm, (0, None), 0) | ||
vec_vqlm = catalyst.vmap( | ||
vqlm, | ||
in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}), | ||
out_axes=0, | ||
) | ||
|
||
|
||
# Mean-squared-error loss function | ||
@jax.jit | ||
def mse_loss(predictions, targets): | ||
return jnp.mean(0.5 * (predictions - targets) ** 2) | ||
|
||
|
||
# Make prediction and compute the loss | ||
@jax.jit | ||
def cost(weights, loss_data): | ||
data, E_target, F_target = loss_data | ||
E_pred = vec_vqlm(data, weights) | ||
|
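The `catalyst.vmap` replacement mirrors `jax.vmap`'s `(0, None)` axis spec, but spells out the broadcast axis for each leaf of the parameter pytree. A toy sketch of that pattern, with a made-up `model` and a single `weights` leaf standing in for the demo's full parameter dictionary:

```python
import catalyst
import jax.numpy as jnp

# Hypothetical stand-in for vqlm: one data point in, one scalar out.
def model(x, params):
    return jnp.sum(x * params["params"]["weights"])

# Axis 0 of `x` is mapped; every leaf of `params` is broadcast (axis None),
# matching the per-leaf in_axes dict used in the hunk above.
vec_model = catalyst.vmap(
    model,
    in_axes=(0, {"params": {"weights": None}}),
    out_axes=0,
)

# As in the PR, the vectorized function is invoked inside a qjit-compiled caller.
@catalyst.qjit
def batched(data, params):
    return vec_model(data, params)

data = jnp.ones((5, 3))                        # batch of 5 three-feature points
params = {"params": {"weights": jnp.ones(3)}}
print(batched(data, params))                   # -> array of shape (5,)
```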
```diff
@@ -424,17 +433,19 @@ def cost(weights, loss_data):
 
 
 # Perform one training step
-@jax.jit
+# This function will be repeatedly called, so we qjit it to exploit the saved runtime from many runs.
+@catalyst.qjit
 def train_step(step_i, opt_state, loss_data):
 
     net_params = get_params(opt_state)
-    loss, grads = jax.value_and_grad(cost, argnums=0)(net_params, loss_data)
-
+    loss = cost(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
     return loss, opt_update(step_i, grads, opt_state)
 
 
 # Return prediction and loss at inference times, e.g. for testing
-@jax.jit
+# This function is also repeatedly called, so qjit it.
+@catalyst.qjit
 def inference(loss_data, opt_state):
 
     data, E_target, F_target = loss_data
```

An inline review comment on the `catalyst.grad` change notes:

> When changing to catalyst, it was discovered that […] To make the demo work, I had to manually change the gradient method to finite difference. This causes significant performance degradation. Possible paths forward: […]
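To make the workaround in that comment concrete: the hunk splits `jax.value_and_grad` into two calls, one evaluating the cost and one taking a finite-difference gradient via `catalyst.grad`. A toy sketch of the substitution, with a made-up quadratic `cost` standing in for the demo's loss; it keeps the PR's `method="fd"` and `h=1e-13` settings, and assumes 64-bit floats are enabled, since a step that small is below float32 resolution.

```python
import jax
import jax.numpy as jnp
import catalyst

jax.config.update("jax_enable_x64", True)  # h=1e-13 needs 64-bit precision

# Hypothetical cost function standing in for the demo's MSE cost.
def cost(w, x):
    return jnp.mean((w * x - 1.0) ** 2)

@catalyst.qjit
def train_step_sketch(w, x):
    # jax.value_and_grad is replaced by two calls: the loss value...
    loss = cost(w, x)
    # ...and a finite-difference gradient, exactly as in the hunk above.
    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(w, x)
    return loss, grads

w = jnp.array(0.5)
x = jnp.array([1.0, 2.0, 3.0])
print(train_step_sketch(w, x))  # -> (loss, d loss / d w)
```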