@@ -274,9 +274,13 @@ class SubroutineInterpreter(PlxprInterpreter):
274
274
* does not allocate a new register upon beginning,
275
275
* does not deallocate the quantum register upon ending,
276
276
* and it does not release the quantum device back to the runtime.
277
+
278
+ Args:
279
+ device (qml.devices.Device)
280
+ shots (qml.measurements.Shots)
277
281
"""
278
282
279
- def __init__ (self , device , shots ):
283
+ def __init__ (self , device , shots : qml . measurements . Shots | int ):
280
284
self ._device = device
281
285
self ._shots = self ._extract_shots_value (shots )
282
286
self .stateref = None
@@ -413,10 +417,6 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
413
417
*args (tuple[TensorLike]): The arguments for the jaxpr.
414
418
Returns:
415
419
list[TensorLike]: the results of the execution.
416
- """
417
- raise NotImplementedError ("Unreachable code until we add subroutine feature" )
418
-
419
- """
420
420
421
421
# We assume we have at least one argument (the qreg)
422
422
assert len(args) > 0
@@ -438,9 +438,10 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
438
438
439
439
return outvals
440
440
"""
441
+ raise NotImplementedError ("Unreachable code until we add subroutine feature" )
441
442
442
443
443
- class QFuncPlxprInterpreter (SubroutineInterpreter , PlxprInterpreter ):
444
+ class QFuncPlxprInterpreter (SubroutineInterpreter ):
444
445
"""An interpreter that converts plxpr into catalyst-variant jaxpr.
445
446
446
447
Args:
@@ -449,9 +450,6 @@ class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter):
449
450
450
451
"""
451
452
452
- def __init__ (self , device , shots : qml .measurements .Shots | int ):
453
- super ().__init__ (device , shots )
454
-
455
453
def setup (self ):
456
454
"""Initialize the stateref and bind the device."""
457
455
if self .stateref is None :
0 commit comments