Error when using jnp.roll with QNode + Dynamic shapes #1615
Tried this; I cannot reproduce the error if I specify the axis for roll:

    cumulative_phase = jnp.roll(phase_divisors, jnp.asarray(1), (0,))

I get the correctly shifted, dynamically shaped results. It works both with and without the QNode. Versions:
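A minimal plain-JAX sketch of the same call, assuming `n` is a static argument: plain `jax.jit` has no dynamic shapes (unlike Catalyst's `qjit`), so this does not reproduce the dynamic-shape case itself. It only shows that with an explicit static axis `(0,)` the shift can still be a traced value, and that for a 1-D array the result matches the default `axis=None`. The function names are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def roll_with_axis(n, shift):
    # n is static here, so jnp.arange(n) has a concrete shape under jax.jit.
    phase_divisors = jnp.arange(n)
    # The shift is traced; the axis is a static Python tuple.
    return jnp.roll(phase_divisors, shift, (0,))


@partial(jax.jit, static_argnums=0)
def roll_without_axis(n, shift):
    phase_divisors = jnp.arange(n)
    # axis=None flattens, rolls, and reshapes back; for a 1-D array this
    # matches rolling along axis 0.
    return jnp.roll(phase_divisors, shift)


print(roll_with_axis(5, 1))     # [4 0 1 2 3]
print(roll_without_axis(5, 1))  # [4 0 1 2 3]
```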
Tried a bit more; apparently the axis has to be specified for roll, and it has to be static:

    import jax.numpy as jnp
    import pennylane as qml
    from catalyst import qjit

    @qjit
    @qml.qnode(qml.device('lightning.qubit', wires=0))
    def circuit(N: int, axis: int, shift_len: int):
        phase_divisors = jnp.arange(N)  # N is a runtime argument, so this array is dynamically shaped
        cumulative_phase = jnp.roll(phase_divisors, jnp.asarray(shift_len), (axis,))  # fine if `axis` is replaced with `0`
        return cumulative_phase

    for i in range(10):
        print(circuit(5, 0, 1))
This fails without the QNode too.

Update: so this led me to wonder, well, we are

    cumulative_phase = jnp.roll(phase_divisors,
                                jnp.asarray(shift_len),
                                jnp.asarray(axis))

But this gives
So I looked into
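A plain-JAX sketch of the static-axis requirement, with hypothetical function names and `n` kept static because plain `jax.jit` does not support dynamic shapes. Marking `axis` as static lets `jnp.roll` see a concrete integer; leaving it traced makes the call fail at trace time, because the axis must be converted to a Python int. The exact exception type may vary across JAX versions, so the sketch catches it broadly.

```python
import jax
import jax.numpy as jnp


def roll_phases(n, axis, shift_len):
    # Hypothetical helper mirroring the repro above, without the QNode.
    phase_divisors = jnp.arange(n)
    return jnp.roll(phase_divisors, jnp.asarray(shift_len), (axis,))


# n and axis are static; only shift_len is traced.
roll_static_axis = jax.jit(roll_phases, static_argnums=(0, 1))

# Only n is static, so axis reaches jnp.roll as a tracer.
roll_traced_axis = jax.jit(roll_phases, static_argnums=(0,))


def traced_axis_fails():
    """Return True if tracing with a non-static axis raises an error."""
    try:
        roll_traced_axis(5, 0, 1)
    except Exception:  # exact type depends on the JAX version
        return True
    return False


print(roll_static_axis(5, 0, 1))  # [4 0 1 2 3]
print(traced_axis_fails())        # True
```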
@paul0403 I don't think we care about a dynamic axis (that is not a paradigm I've ever seen in Catalyst programs). Let me try to confirm whether I can reproduce the error again.
Raised the issue upstream: jax-ml/jax#27877. Since this error also occurs with JAX on its own, without Catalyst, I think we can close this issue and mark it as done.
The following program fails:
Without the QNode, it succeeds.