Skip to content

[WIP] Set shots #1784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

[WIP] Set shots #1784

wants to merge 6 commits into from

Conversation

JerryChen97
Copy link
Contributor

@JerryChen97 JerryChen97 commented Jun 4, 2025

Context:
Make non-plxpr qjit compatible with new qml.set_shots scheme: shots information will be decoupled from device

Description of the Change:

  • Apply set_shots before catalyst starts to consume device and qnode
  • If possible, apply all the user transforms as well.

Benefits:

Possible Drawbacks:

Related GitHub Issues:
[sc-90929]

@JerryChen97 JerryChen97 marked this pull request as ready for review June 4, 2025 20:08
@JerryChen97 JerryChen97 changed the title Set shots [WIP] Set shots Jun 4, 2025
Copy link
Contributor

github-actions bot commented Jun 4, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

Copy link

codecov bot commented Jun 4, 2025

Codecov Report

Attention: Patch coverage is 53.84615% with 6 lines in your changes missing coverage. Please review.

Project coverage is 96.54%. Comparing base (dbc66d6) to head (947ae5f).
Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
frontend/catalyst/jax_tracer.py 53.84% 1 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1784      +/-   ##
==========================================
- Coverage   96.60%   96.54%   -0.07%     
==========================================
  Files          82       82              
  Lines        9211     9224      +13     
  Branches      872      878       +6     
==========================================
+ Hits         8898     8905       +7     
- Misses        254      255       +1     
- Partials       59       64       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@JerryChen97 JerryChen97 self-assigned this Jun 5, 2025
@JerryChen97 JerryChen97 requested a review from paul0403 June 5, 2025 18:18
@JerryChen97
Copy link
Contributor Author

@paul0403 It's still far from being completed but I just want to know whether it's a good idea to directly modify the device._shots passed to catalyst. I chose this because this feels a minimal thing to make it happen without having any other parts of catalyst altered.

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @JerryChen97!

@@ -1328,6 +1328,31 @@ def trace_quantum_function(
out_tree: PyTree shapen of the result
"""

if qml.transforms.set_shots in qnode.transform_program:
Copy link
Contributor

@dime10 dime10 Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to know whether it's a good idea to directly modify the device._shots passed to catalyst.

I would say setting the shots on the QJITDevice like you're doing is reasonable. Although I might suggest moving this code to the QFunc class in catalyst/qfunc.py, since it's not really related to tracing and this module is already fairly beefy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we concerned about the device at all? Isn't the point of the PL work to decouple the shots from the device? 😅 Sorry if I'm missing something here @JerryChen97

@@ -1328,6 +1328,31 @@ def trace_quantum_function(
out_tree: PyTree shapen of the result
"""

if qml.transforms.set_shots in qnode.transform_program:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I don't fully understand is what difference this approach makes, since the information is still static if it is in the qnode program. That is, the user is not able to set different shot values when calling their function, right?

I guess with this the user can set shots on a per qnode level, rather than a per device level?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is exactly what I thought as well! I don't expect this direct mutation to device to be the final solution at all. Just curious about any potential risk if this device mutation happens

Copy link
Contributor

@dime10 dime10 Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No risk I think because each qnode triggers the creation of a new QJITDevice, the scoping there matches.

@@ -16,7 +16,7 @@ enzyme=v0.0.149

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt
pennylane=0.42.0-dev33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer needed, we updated to 48 (#1795 )

Comment on lines +1333 to +1334
user_transform = qnode.transform_program
set_shots_transform = TransformProgram()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, is the new set_shots being implemented as an entire transform in pennylane? 😨

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be a big issue for Catalyst?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. I was just hoping that we can replace the shots = get_device_shots(device) in this jax tracer file for however PL handles it now on the qnode...

Copy link
Contributor

@dime10 dime10 Jun 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be a big issue for Catalyst?

@paul0403 Maybe you can get an overview of how the core team plans to handle shots going forward, that could be helpful.

But for us currently, I would say no, there is no issue with the proposed change. Whether it's the most sensible option going forward might be a different question.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On PL side our approach of how to decouple the shots from devices is also under changes. There might be a better way to do here as well. I'll update asap it's done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants