Skip to content

Crash when computing gradients of QNode that uses outer scope arguments #1608

Closed
@mehrdad2m

Description

@mehrdad2m

The following test fails because the QNode (`circuit`) uses `y`, an argument of the enclosing qjit-compiled function (`workflow`), which is not passed to the QNode as a parameter.

# Minimal reproducer: take the gradient of a QNode that reads an argument of
# the enclosing qjit-compiled function from the outer scope.
# NOTE(review): imports are not shown in the report; presumably `qml` is
# PennyLane, `jnp` is jax.numpy, and `qjit`/`grad` come from Catalyst - confirm.
@qjit
def workflow(x, y):

    dev = qml.device("lightning.qubit", wires=1)

    @qml.qnode(dev)
    def circuit(x):
        # This `x` is `circuit`'s own parameter (it shadows `workflow`'s `x`).
        qml.RX(jnp.pi * x , wires=0)
        # `y` is NOT a parameter of `circuit`; it is read from `workflow`'s scope.
        qml.RX(jnp.pi * y , wires=0)
        return qml.expval(qml.PauliY(0))

    # Differentiating `circuit` is the step that fails: the traceback in this
    # report ends in a TracerArrayConversionError.
    g = grad(circuit)
    return g(x)

print(workflow(1.0, 0.25))

Traceback:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
    print(workflow(1.0, 0.25))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 602, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 675, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 550, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 533, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 70, in workflow
    return g(x)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
    results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
    print(workflow(1.0, 0.25))
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in jit_compile
    self.mlir_module = self.generate_ir()
                       ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 696, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 572, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 75, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 144, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 526, in _grad_lowering
    nparray = np.asarray(const)
              ^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in __array__
    raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Note that the failure only happens when taking the gradient of the QNode. Calling the QNode directly, as in the following code, works fine:

# Control case: the same circuit, called directly instead of through `grad`.
# NOTE(review): imports are not shown in the report; presumably `qml` is
# PennyLane, `jnp` is jax.numpy, and `qjit` comes from Catalyst - confirm.
@qjit
def workflow(x, y):

    dev = qml.device("lightning.qubit", wires=1)

    @qml.qnode(dev)
    def circuit(x):
        # This `x` is `circuit`'s own parameter (it shadows `workflow`'s `x`).
        qml.RX(jnp.pi * x , wires=0)
        # `y` is still read from `workflow`'s scope, exactly as in the failing case.
        qml.RX(jnp.pi * y , wires=0)
        return qml.expval(qml.PauliY(0))

    # No `grad` here: `g` is just the QNode itself, so this is a plain evaluation.
    g = circuit
    return g(x)

print(workflow(1.0, 0.25))

which prints:

0.7071067811865475

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions