Skip to content

Commit e14f86f

Browse files
committed
refactor strategies
1 parent 40b9fea commit e14f86f

11 files changed

+924
-266
lines changed

docs/index.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,15 @@ If you use GADD in your work, please cite:
1919
journal={arXiv preprint arXiv:2403.02294},
2020
year={2024}
2121
}
22+
23+
24+
.. toctree::
25+
:hidden:
26+
:maxdepth: 2
27+
:caption: API Reference
28+
29+
api/gadd
30+
api/sequences
31+
api/utility_functions
32+
api/group_operations
33+
api/circuit_padding

gadd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""GADD: Genetic Algorithm for Dynamical Decoupling optimization."""
22

33
from .gadd import GADD, TrainingConfig, TrainingState, TrainingResult
4-
from .sequences import DDSequence, DDStrategy, StandardSequences
4+
from .strategies import DDSequence, DDStrategy, StandardSequences
55
from .utility_functions import (
66
UtilityFunction,
77
SuccessProbability,

gadd/circuit_padding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
PadDynamicalDecoupling,
1717
)
1818

19-
from .sequences import DDStrategy
19+
from .strategies import DDStrategy
2020

2121

2222
class DDPulse:

gadd/gadd.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import time
1010
import os
1111

12-
import rustworkx as rx
13-
1412
import numpy as np
1513
from numpy.random import BitGenerator, Generator, SeedSequence, default_rng
1614

@@ -20,7 +18,7 @@
2018
from qiskit_ibm_runtime import Sampler
2119
import matplotlib.pyplot as plt
2220

23-
from .sequences import DDStrategy, DDSequence, StandardSequences
21+
from .strategies import DDStrategy, DDSequence, StandardSequences, ColorAssignment
2422
from .group_operations import complete_sequence_to_identity
2523
from .circuit_padding import apply_dd_strategy as _apply_dd_strategy
2624
from .utility_functions import UtilityFunction, SuccessProbability
@@ -38,7 +36,7 @@ class TrainingConfig:
3836
mutation_probability (float): Initial probability of mutation.
3937
optimization_level (int): Qiskit transpilation optimization level.
4038
shots (int): Number of shots for quantum circuit execution.
41-
num_colors (int): Number of distinct sequences per strategy (``C`` in the paper).
39+
num_colors (int): Number of distinct sequences per strategy (``k`` in the paper).
4240
group_size (int): Size of the decoupling group (``|G|`` in the paper).
4341
mode (str): Mode for generating initial population.
4442
dynamic_mutation (bool): Whether to dynamically adjust mutation probability.
@@ -112,7 +110,7 @@ def __init__(
112110
self,
113111
backend: Optional[Backend] = None,
114112
utility_function: Optional[UtilityFunction] = None,
115-
coloring: Optional[Dict] = None,
113+
coloring: Optional[Union[Dict, ColorAssignment]] = None,
116114
seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None,
117115
config: Optional[TrainingConfig] = None,
118116
):
@@ -126,16 +124,21 @@ def __init__(
126124

127125
# Set up coloring
128126
if coloring is None and backend is not None:
129-
# Default greedy coloring of coupling map
130-
if hasattr(backend, "coupling_map") and backend.coupling_map:
131-
self._coloring = rx.graph_greedy_color(
132-
backend.coupling_map.graph.to_undirected()
133-
)
134-
else:
135-
# Fallback: all qubits same color
136-
self._coloring = {i: 0 for i in range(backend.num_qubits)}
127+
# Default coloring from backend
128+
self._coloring = ColorAssignment(backend=backend)
129+
elif isinstance(coloring, dict):
130+
# Convert dict to ColorAssignment
131+
# Assume dict is qubit->color mapping
132+
color_to_qubits = {}
133+
for qubit, color in coloring.items():
134+
if color not in color_to_qubits:
135+
color_to_qubits[color] = []
136+
color_to_qubits[color].append(qubit)
137+
self._coloring = ColorAssignment.from_manual_assignment(color_to_qubits)
138+
elif isinstance(coloring, ColorAssignment):
139+
self._coloring = coloring
137140
else:
138-
self._coloring = coloring or {}
141+
self._coloring = None
139142

140143
# Decoupling group - matches paper's group G
141144
self._decoupling_group = ["Ip", "Im", "Xp", "Xm", "Yp", "Ym", "Zp", "Zm"]
@@ -166,11 +169,22 @@ def coloring(self):
166169

167170
@coloring.setter
168171
def coloring(self, coloring):
169-
if not isinstance(coloring, dict):
172+
if isinstance(coloring, dict):
173+
# Convert dict to ColorAssignment
174+
color_to_qubits = {}
175+
for qubit, color in coloring.items():
176+
if color not in color_to_qubits:
177+
color_to_qubits[color] = []
178+
color_to_qubits[color].append(qubit)
179+
self._coloring = ColorAssignment.from_manual_assignment(color_to_qubits)
180+
elif isinstance(coloring, ColorAssignment):
181+
self._coloring = coloring
182+
elif coloring is None:
183+
self._coloring = None
184+
else:
170185
raise TypeError(
171-
"Coloring must be a dictionary keyed by qubit index with color values"
186+
"Coloring must be a dictionary, ColorAssignment instance, or None"
172187
)
173-
self._coloring = coloring
174188

175189
def apply_dd(
176190
self,
@@ -197,13 +211,15 @@ def apply_dd(
197211
# Use provided backend or fall back to instance backend
198212
backend = backend or self._backend
199213

200-
# Get coloring for the backend
201-
if backend and hasattr(backend, "coupling_map") and backend.coupling_map:
202-
coloring = rx.graph_greedy_color(backend.coupling_map.graph.to_undirected())
214+
# Get coloring for the circuit
215+
if self._coloring is not None:
216+
coloring_dict = self._coloring.to_dict()
217+
elif backend:
218+
color_assignment = ColorAssignment(backend=backend)
219+
coloring_dict = color_assignment.to_dict()
203220
else:
204-
coloring = self._coloring or {
205-
i: 0 for i in range(target_circuit.num_qubits)
206-
}
221+
# Fallback: all qubits same color
222+
coloring_dict = {i: 0 for i in range(target_circuit.num_qubits)}
207223

208224
# Get instruction durations from backend if available
209225
instruction_durations = None
@@ -217,7 +233,7 @@ def apply_dd(
217233
return _apply_dd_strategy(
218234
target_circuit,
219235
strategy,
220-
coloring,
236+
coloring_dict,
221237
instruction_durations=instruction_durations,
222238
staggered=staggered,
223239
)

gadd/sequences.py renamed to gadd/strategies.py

Lines changed: 140 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
"""
2+
DD strategies, sequences, and coloring assignments.
3+
"""
4+
15
from dataclasses import dataclass
26
from typing import List, Dict, Optional, Union
7+
8+
import numpy as np
9+
import rustworkx as rx
10+
311
from qiskit.circuit.library import IGate, XGate, YGate, U1Gate, RZGate
412
from qiskit import QuantumCircuit
5-
import numpy as np
613

714

815
# Default gateset for DD pulses
@@ -211,34 +218,127 @@ def from_dict(cls, data: Dict[str, any]) -> "DDStrategy":
211218

212219

213220
class ColorAssignment:
214-
"""Assignment of device qubits to colors."""
221+
"""Assignment of device qubits to colors based on connectivity graph."""
215222

216-
def __init__(self, assignments: Dict[int, List[int]]):
223+
def __init__(
224+
self, graph: Optional[rx.PyGraph] = None, backend: Optional["Backend"] = None
225+
):
217226
"""
218227
Initialize color assignment.
219228
220229
Args:
221-
assignments: Dictionary mapping colors to lists of qubit indices
230+
graph: Connectivity graph where nodes are qubits and edges are connections.
231+
If None and backend is provided, will extract from backend.
232+
backend: Backend to extract connectivity from if graph not provided.
233+
234+
Raises:
235+
ValueError: If neither graph nor backend is provided.
222236
"""
223-
self.assignments = assignments
224-
self._validate()
237+
if graph is None and backend is None:
238+
raise ValueError("Must provide either graph or backend")
239+
240+
if graph is None:
241+
# Extract from backend
242+
if hasattr(backend, "coupling_map") and backend.coupling_map:
243+
self.graph = backend.coupling_map.graph.to_undirected()
244+
else:
245+
# Fallback: create complete graph
246+
n_qubits = backend.num_qubits if hasattr(backend, "num_qubits") else 1
247+
self.graph = rx.PyGraph()
248+
self.graph.add_nodes_from(range(n_qubits))
249+
else:
250+
self.graph = graph
251+
252+
# Perform graph coloring
253+
self._color_map = rx.graph_greedy_color(self.graph)
254+
255+
# Create assignments dictionary (color -> list of qubits)
256+
self.assignments = {}
257+
for qubit, color in self._color_map.items():
258+
if color not in self.assignments:
259+
self.assignments[color] = []
260+
self.assignments[color].append(qubit)
261+
225262
# Create reverse mapping for efficiency
226-
self._qubit_to_color = {}
263+
self._qubit_to_color = self._color_map.copy()
264+
265+
@classmethod
266+
def from_circuit(cls, circuit: QuantumCircuit) -> "ColorAssignment":
267+
"""
268+
Create color assignment from circuit connectivity.
269+
270+
Args:
271+
circuit: Quantum circuit to extract connectivity from.
272+
273+
Returns:
274+
ColorAssignment based on circuit structure.
275+
"""
276+
# Build connectivity graph from circuit
277+
graph = rx.PyGraph()
278+
qubits = set()
279+
edges = set()
280+
281+
for instruction in circuit.data:
282+
if instruction.operation.name in ["cx", "ecr", "cz"]: # Two-qubit gates
283+
qubits_involved = [q._index for q in instruction.qubits]
284+
if len(qubits_involved) == 2:
285+
q1, q2 = qubits_involved
286+
qubits.add(q1)
287+
qubits.add(q2)
288+
edges.add((min(q1, q2), max(q1, q2)))
289+
else:
290+
# Track all qubits
291+
for q in instruction.qubits:
292+
qubits.add(q._index)
293+
294+
# Add all qubits as nodes
295+
graph.add_nodes_from(sorted(qubits))
296+
297+
# Add edges
298+
for q1, q2 in edges:
299+
graph.add_edge(q1, q2, None)
300+
301+
return cls(graph=graph)
302+
303+
@classmethod
304+
def from_manual_assignment(
305+
cls, assignments: Dict[int, List[int]]
306+
) -> "ColorAssignment":
307+
"""
308+
Create from manual color assignments.
309+
310+
Args:
311+
assignments: Dictionary mapping colors to lists of qubit indices.
312+
313+
Returns:
314+
ColorAssignment with specified assignments.
315+
"""
316+
# Create a graph where qubits with different colors are connected
317+
graph = rx.PyGraph()
318+
all_qubits = set()
319+
for qubits in assignments.values():
320+
all_qubits.update(qubits)
321+
322+
graph.add_nodes_from(sorted(all_qubits))
323+
324+
# Connect qubits with different colors
325+
colors = list(assignments.keys())
326+
for i in range(len(colors)):
327+
for j in range(i + 1, len(colors)):
328+
for q1 in assignments[colors[i]]:
329+
for q2 in assignments[colors[j]]:
330+
graph.add_edge(q1, q2, None)
331+
332+
# Create instance and override the computed coloring
333+
instance = cls(graph=graph)
334+
instance.assignments = assignments
335+
instance._qubit_to_color = {}
227336
for color, qubits in assignments.items():
228337
for qubit in qubits:
229-
self._qubit_to_color[qubit] = color
338+
instance._qubit_to_color[qubit] = color
339+
instance._color_map = instance._qubit_to_color.copy()
230340

231-
def _validate(self):
232-
"""Validate color assignments."""
233-
# Check for overlapping qubit assignments
234-
assigned = set()
235-
for color, qubits in self.assignments.items():
236-
if not isinstance(qubits, list):
237-
raise TypeError(f"Qubits for color {color} must be a list")
238-
overlap = assigned.intersection(qubits)
239-
if overlap:
240-
raise ValueError(f"Qubits {overlap} assigned to multiple colors")
241-
assigned.update(qubits)
341+
return instance
242342

243343
def get_color(self, qubit: int) -> Optional[int]:
244344
"""Get color assigned to a qubit."""
@@ -252,3 +352,24 @@ def get_qubits(self, color: int) -> List[int]:
252352
def n_colors(self) -> int:
253353
"""Number of colors used in assignment."""
254354
return len(self.assignments)
355+
356+
def to_dict(self) -> Dict[int, int]:
357+
"""
358+
Convert to qubit->color mapping dictionary.
359+
360+
Returns:
361+
Dictionary mapping qubit indices to color values.
362+
"""
363+
return self._color_map.copy()
364+
365+
def validate_coloring(self) -> bool:
366+
"""
367+
Validate that the coloring is proper (no adjacent nodes have same color).
368+
369+
Returns:
370+
True if coloring is valid, False otherwise.
371+
"""
372+
for edge in self.graph.edge_list():
373+
if self._color_map[edge[0]] == self._color_map[edge[1]]:
374+
return False
375+
return True

tests/test_circuit_padding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
get_instruction_duration,
1111
apply_dd_strategy,
1212
)
13-
from gadd.sequences import DDSequence, DDStrategy
13+
from gadd.strategies import DDSequence, DDStrategy
1414

1515

1616
class TestDDPulse(unittest.TestCase):

0 commit comments

Comments
 (0)