diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 31bc82220..ab1d001bd 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -138,13 +138,13 @@ def __hash__(self): # pragma: nocover return self.hash_value -class ConcreteQbit(AbstractQbit): +class ConcreteQbit: """Concrete Qbit.""" def _qbit_lowering(aval): assert isinstance(aval, AbstractQbit) - return (ir.OpaqueType.get("quantum", "bit"),) + return ir.OpaqueType.get("quantum", "bit") # @@ -162,13 +162,13 @@ def __hash__(self): return self.hash_value -class ConcreteQreg(AbstractQreg): +class ConcreteQreg: """Concrete quantum register.""" def _qreg_lowering(aval): assert isinstance(aval, AbstractQreg) - return (ir.OpaqueType.get("quantum", "reg"),) + return ir.OpaqueType.get("quantum", "reg") # @@ -2357,3 +2357,5 @@ def _scalar_abstractify(t): pytype_aval_mappings[type] = _scalar_abstractify pytype_aval_mappings[jax._src.numpy.scalar_types._ScalarMeta] = _scalar_abstractify +pytype_aval_mappings[ConcreteQbit] = lambda _: AbstractQbit() +pytype_aval_mappings[ConcreteQreg] = lambda _: AbstractQreg()