
Error when using jnp.roll with QNode + Dynamic shapes #1615


Open
dime10 opened this issue Apr 4, 2025 · 4 comments
Labels
bug Something isn't working

Comments

Contributor

dime10 commented Apr 4, 2025

The following program fails:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import *

@qjit
@qml.qnode(qml.device('lightning.qubit', wires=0))
def circuit(N: int):
    phase_divisors = jnp.arange(N)
    cumulative_phase = jnp.roll(phase_divisors, jnp.asarray(1))
    return cumulative_phase

print(circuit(2))
```

Without the QNode, it succeeds.
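For comparison, here is a plain-JAX analogue of the program above, with no Catalyst and no QNode. It is a sketch under one assumption worth stating loudly: plain `jax.jit` does not support dynamic shapes, so `N` is marked static with `static_argnums`, which means this does not exercise the dynamic-shape path that triggers the bug. The name `rolled_arange` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Plain-JAX analogue of the reported program (no Catalyst, no QNode).
# Assumption: N is static here, because jax.jit has no dynamic shapes
# and jnp.arange(N) needs a concrete length.
@partial(jax.jit, static_argnums=0)
def rolled_arange(N: int):
    phase_divisors = jnp.arange(N)
    # Same call as in the report: axis omitted, shift passed as an array.
    return jnp.roll(phase_divisors, jnp.asarray(1))


print(rolled_arange(2))  # [1 0]
```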

dime10 added the bug label on Apr 4, 2025
dime10 changed the title from "Error when using jnp.roll with QNode" to "Error when using jnp.roll with QNode + Dynamic shapes" on Apr 4, 2025
Member

paul0403 commented Apr 7, 2025

I tried this and cannot reproduce the error if I pass the axis to `roll` explicitly:

```python
cumulative_phase = jnp.roll(phase_divisors, jnp.asarray(1), (0,))
```

I get correctly shifted, dynamically shaped results. It works both with and without the QNode.
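A quick sanity check that passing the axis does not change the result for this program: for a 1-D array, rolling with the default `axis=None` (flatten, roll, restore the shape) and rolling along axis `0` give the same values. This is a plain-JAX sketch, not the Catalyst program, and `roll_both_ways` is a hypothetical helper; the shift is traced under `jax.jit` while the axis stays a static Python tuple.

```python
import jax
import jax.numpy as jnp


@jax.jit
def roll_both_ways(x, shift):
    # shift is a traced scalar here; the axis argument is a static Python tuple.
    default_axis = jnp.roll(x, shift)         # axis=None: flatten, roll, restore shape
    explicit_axis = jnp.roll(x, shift, (0,))  # roll along axis 0 only
    return default_axis, explicit_axis


a, b = roll_both_ways(jnp.arange(5), jnp.asarray(1))
print(a)  # [4 0 1 2 3]
print(b)  # [4 0 1 2 3]
```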

Versions:

```
(cat_pyenv) paul.wang@L7450-H298SW3:~/catalyst_new/catalyst$ which-pennies
amazon-braket-pennylane-plugin    1.31.2
PennyLane                         0.41.0.dev56
PennyLane-Catalyst                0.11.0.dev71 /home/paul.wang/catalyst_new/catalyst
PennyLane_Lightning               0.41.0.dev22
PennyLane_Lightning_Kokkos        0.41.0.dev22
pennylane-sphinx-theme            0.13.0
jax                               0.4.28
jaxlib                            0.4.28
```

```
(cat_pyenv) paul.wang@L7450-H298SW3:~/catalyst_new/catalyst$ git log
commit 69f1047efc734fba844f6daa2c6b49e74c3f1171
...
```


Member

paul0403 commented Apr 7, 2025

After trying a bit more, it appears the axis has to be specified for `roll` and has to be static:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

@qjit
@qml.qnode(qml.device('lightning.qubit', wires=0))
def circuit(N: int, axis: int, shift_len: int):
    phase_divisors = jnp.arange(N)
    cumulative_phase = jnp.roll(phase_divisors, jnp.asarray(shift_len), (axis,))  # works if `axis` is replaced with `0`
    return cumulative_phase

for i in range(10):
    print(circuit(5, 0, 1))
```

This raises:

```
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
```

This fails without the QNode too.
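One way to keep `axis` static while still passing it as an argument is to mark it static at compile time, so that inside the function it is an ordinary Python `int`. This is a plain-JAX sketch assuming `jax.jit`'s `static_argnames`; the function name is hypothetical, and whether the same approach carries over to the Catalyst program (for example through a static-argument option on `qjit`) is an assumption that was not tested in this thread.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `axis` is static, so jnp.roll receives a plain Python int it can use as an index;
# `shift_len` stays traced and is handled by jnp.roll's dynamic-shift path.
@partial(jax.jit, static_argnames="axis")
def roll_with_static_axis(x, shift_len, axis):
    return jnp.roll(x, jnp.asarray(shift_len), (axis,))


for _ in range(3):
    print(roll_with_static_axis(jnp.arange(5), 1, axis=0))  # [4 0 1 2 3]
```

Changing `axis` triggers a recompilation, which is the usual trade-off of static arguments.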


Update:

This made me wonder: we wrap the shift length in `jnp.asarray` to make it dynamic, so why not do the same for the axis?

```python
cumulative_phase = jnp.roll(
    phase_divisors,
    jnp.asarray(shift_len),
    jnp.asarray(axis),
)
```

But this gives:

```
  ...
  File "/home/paul.wang/catalyst_new/catalyst/cat_pyenv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 5186, in roll
    axis = _ensure_index_tuple(axis)
  File "/home/paul.wang/catalyst_new/catalyst/cat_pyenv/lib/python3.10/site-packages/jax/_src/api_util.py", line 53, in _ensure_index_tuple
    x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
  File "/home/paul.wang/catalyst_new/catalyst/cat_pyenv/lib/python3.10/site-packages/jax/_src/core.py", line 1492, in concrete_or_error
    raise ConcretizationTypeError(val, context)
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int64[].
expected a static index or sequence of indices.
```

So I looked into `jax/_src/numpy/lax_numpy.py`, line 5186, in `roll`:

Contributor Author

dime10 commented Apr 8, 2025

@paul0403 I don't think we care about dynamic axis (that is not a paradigm I've ever seen in Catalyst programs). Let me try to confirm if I can reproduce the error again.

Member

paul0403 commented Apr 9, 2025

Raised the issue upstream: jax-ml/jax#27877

Since this error also occurs with JAX alone, without Catalyst, I think we can close this issue and mark it as done.
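As a minimal plain-JAX illustration that this happens without Catalyst: tracing `jnp.roll` with a traced axis fails in both of the forms tried earlier in the thread. With the JAX version from the report these surfaced as `TracerIntegerConversionError` (a tuple containing a tracer) and `ConcretizationTypeError` (a traced 0-d array). The sketch below relies only on both being `TypeError` subclasses, and `roll_error` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def roll_error(make_axis):
    """Trace jnp.roll with a traced axis and return the exception raised, if any."""

    @jax.jit
    def f(x, axis):
        # `axis` is a traced scalar here; make_axis decides how it is passed to roll.
        return jnp.roll(x, jnp.asarray(1), make_axis(axis))

    try:
        f(jnp.arange(5), 0)
    except TypeError as err:  # both JAX errors mentioned above subclass TypeError
        return err
    return None


# Axis passed as a tuple containing a tracer (as in the first attempt).
print(type(roll_error(lambda a: (a,))).__name__)
# Axis passed as a traced 0-d array (as in the jnp.asarray attempt).
print(type(roll_error(jnp.asarray)).__name__)
```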


2 participants