Skip to content

Support Static and Dynamic Variables in PLxPR Programs with QJIT #1810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 20, 2025

Conversation

sengthai
Copy link
Contributor

@sengthai sengthai commented Jun 16, 2025

Context:
Control over static variables in compilation is a general necessity. At the moment, with qml.capture.enable(), additionally specifying a static variable in qjit does not work. With this epic, we want to add support for static variables in the plxpr-qjit execution pipeline.

Description of the Change:

  • Updated the trace_from_pennylane function signature to include dynamic_argnums.
  • Modified the function implementation to correctly handle dynamic arguments when calling from_plxpr.

[[sc-93318]]

@sengthai sengthai force-pushed the support-static-dynamic-var branch from 5a1f63a to 717a2f6 Compare June 16, 2025 13:58
@sengthai sengthai marked this pull request as ready for review June 16, 2025 16:09
@sengthai sengthai requested a review from paul0403 June 16, 2025 16:09
Copy link

codecov bot commented Jun 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.56%. Comparing base (b7b709b) to head (1cd4888).
Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1810   +/-   ##
=======================================
  Coverage   96.56%   96.56%           
=======================================
  Files          85       85           
  Lines        9324     9326    +2     
  Branches      871      872    +1     
=======================================
+ Hits         9004     9006    +2     
  Misses        261      261           
  Partials       59       59           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@dime10 dime10 requested a review from a team June 16, 2025 18:26
@paul0403
Copy link
Member

This will need to wait for #1801 , by the way.

Copy link
Member

@paul0403 paul0403 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, simple fix 💯

Remember to add a test : )

@sengthai sengthai force-pushed the support-static-dynamic-var branch from b7130c4 to df483ab Compare June 17, 2025 19:11
@paul0403
Copy link
Member

I just realized this doesn't actually produce static IR...

module @captured_circuit_1 {
  func.func public @jit_captured_circuit_1(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %cst = stablehlo.constant dense<1.500000e+00> : tensor<f64>
    %0 = catalyst.launch_kernel @module_captured_circuit_1::@captured_circuit_1(%cst, %arg0) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }

    func.func public @captured_circuit_1(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
      %c0_i64 = arith.constant 0 : i64
      quantum.device shots(%c0_i64) ["/home/paul.wang/.local/lib/python3.10/site-packages/pennylane_lightning/liblightning_qubit_catalyst.so", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
      %0 = quantum.alloc( 1) : !quantum.reg
      %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
      %extracted = tensor.extract %arg0[] : tensor<f64>
      %out_qubits = quantum.custom "RX"(%extracted) %1 : !quantum.bit
      %extracted_0 = tensor.extract %arg1[] : tensor<f64>
      %out_qubits_1 = quantum.custom "RY"(%extracted_0) %out_qubits : !quantum.bit
      %2 = quantum.namedobs %out_qubits_1[ PauliZ] : !quantum.obs
      %3 = quantum.expval %2 : f64
      %from_elements = tensor.from_elements %3 : tensor<f64>
      %4 = quantum.insert %0[ 0], %out_qubits_1 : !quantum.reg, !quantum.bit
      quantum.dealloc %4 : !quantum.reg
      quantum.device_release
      return %from_elements : tensor<f64>
    }
}

It's only static in the sense the the qnode is being called with a static number. The underlying kernel is still dynamic.

@paul0403
Copy link
Member

paul0403 commented Jun 20, 2025

https://github.com/jax-ml/jax/blob/main/jax/_src/interpreters/partial_eval.py#L2408 this line in trace_to_jaxpr_dynamic2 in jax is changing the in_type. Before this line in_type is just one tracer, after it's two. It somehow promoted the static arg into a tracer again.

However this is only happening for the plxpr path.

@paul0403
Copy link
Member

Note that the above is not from_plxpr, but the

plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)

aka the plxpr does not pick up the static arg already

@paul0403
Copy link
Member

https://github.com/jax-ml/jax/blob/main/jax/_src/interpreters/partial_eval.py#L2408 this line in trace_to_jaxpr_dynamic2 in jax is changing the in_type. Before this line in_type is just one tracer, after it's two. It somehow promoted the static arg into a tracer again.

However this is only happening for the plxpr path.

The only difference I can see before reaching this f.call_wrapped line is that in the good old catalyst frontend, on the qnode we have interface="auto", but in plxpr we have interface="jax"

@paul0403
Copy link
Member

Ok, I now think all of my troubles are because the qjit static argnums and the qnode static argnums are fighting with each other 😅

To which I strongly suggest we disallow specifying both at the same time!

@dime10
Copy link
Contributor

dime10 commented Jun 20, 2025

Ok, I now think all of my troubles are because the qjit static argnums and the qnode static argnums are fighting with each other 😅

To which I strongly suggest we disallow specifying both at the same time!

Maybe we can clear it in the qnode when it's specified in qjit?

No I guess that's impossible because we never really encounter the qnode on the catalyst side, happy to disallow this case for now :)

@paul0403
Copy link
Member

Ok, I now think all of my troubles are because the qjit static argnums and the qnode static argnums are fighting with each other 😅
To which I strongly suggest we disallow specifying both at the same time!

Maybe we can clear it in the qnode when it's specified in qjit?

No I guess that's impossible because we never really encounter the qnode on the catalyst side, happy to disallow this case for now :)

So we actually already do this for non-plxpr pipeline https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/jit.py#L587

But this is not picked up by from_plxpr because it goes through its own make_jaxpr2, which sees a qnode with the original qnode static argnums. This fails because in trace_to_jaxpr_dynamic2 the static argnums followed is the one in the wrapped function. I added a comment.

But anyway, if we are fine with not explicitly disallowing two configs together, then this PR is now ready for merging. I also updated the tests to directly check the gates are using the static angles.

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@paul0403 paul0403 merged commit 7496256 into main Jun 20, 2025
38 checks passed
@paul0403 paul0403 deleted the support-static-dynamic-var branch June 20, 2025 22:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants