Skip to content

Commit c0d6307

Browse files
committed
update security scan and tests
1 parent da53882 commit c0d6307

File tree

7 files changed

+165
-51
lines changed

7 files changed

+165
-51
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,25 +144,25 @@ jobs:
144144
run: |
145145
python -m pip install --upgrade pip
146146
pip install -e ".[dev]"
147-
pip install safety bandit
147+
pip install pip-audit bandit
148148
149-
- name: Run safety check for known vulnerabilities
149+
- name: Run pip-audit to check for known vulnerabilities
150150
run: |
151-
safety scan --output json > safety-report.json || true
152-
safety scan --detailed-output
151+
pip-audit --format=json > pip-audit-report.json || true
152+
pip-audit -v
153153
154154
- name: Run bandit security linter
155155
run: |
156156
bandit -r gadd/ -f json -o bandit-report.json || true
157-
bandit -r gadd/ -ll
157+
bandit -r gadd/
158158
159159
- name: Upload security reports
160160
uses: actions/upload-artifact@v4
161161
if: always()
162162
with:
163163
name: security-reports
164164
path: |
165-
safety-report.json
165+
pip-audit-report.json
166166
bandit-report.json
167167
168168
# Build Documentation

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def custom_fidelity(circuit, result):
149149
best_strategy, result = gadd.train(
150150
sampler=sampler,
151151
training_circuit=circuit,
152-
utility_function=custom_fidelity
152+
utility_function=UtilityFactory.custom(custom_fidelity, "Custom Fidelity")
153153
)
154154
```
155155

gadd/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
OneNormDistance,
99
GHZUtility,
1010
UtilityFactory,
11-
create_utility_function,
1211
)
1312
from .experiments import (
1413
create_bernstein_vazirani_circuit,
@@ -32,7 +31,6 @@
3231
"OneNormDistance",
3332
"GHZUtility",
3433
"UtilityFactory",
35-
"create_utility_function",
3634
"create_bernstein_vazirani_circuit",
3735
"create_ghz_circuit",
3836
"run_bv_experiment",

gadd/experiments.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from qiskit_ibm_runtime import Sampler
88

99
from .gadd import GADD
10-
from .utility_functions import create_utility_function
10+
from .utility_functions import UtilityFactory
1111

1212

1313
def create_bernstein_vazirani_circuit(bitstring: str) -> QuantumCircuit:
@@ -90,9 +90,7 @@ def run_bv_experiment(
9090
circuit = create_bernstein_vazirani_circuit(bitstring)
9191

9292
# Create utility function
93-
utility_function = create_utility_function(
94-
circuit, utility_type="success_probability", target_state=bitstring
95-
)
93+
utility_function = UtilityFactory.success_probability(bitstring)
9694

9795
# Standard sequences to compare
9896
comparison_seqs = kwargs.pop(
@@ -144,7 +142,7 @@ def run_ghz_experiment(
144142
circuit = create_ghz_circuit(n_qubits)
145143

146144
# Create utility function
147-
utility_function = create_utility_function(circuit, utility_type="ghz")
145+
utility_function = UtilityFactory.ghz_state(n_qubits)
148146

149147
# Standard sequences to compare
150148
comparison_seqs = kwargs.pop(

gadd/utility_functions.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -109,41 +109,6 @@ def verify_state(self, state: str, counts: Dict[str, float]) -> None:
109109
raise ValueError("Target state must be a binary string")
110110

111111

112-
def create_utility_function(
113-
circuit: QuantumCircuit,
114-
utility_type: str = "success_probability",
115-
target_state: Optional[Union[str, int]] = None,
116-
**kwargs,
117-
) -> UtilityFunction:
118-
"""
119-
Create a utility function for a given circuit.
120-
121-
Args:
122-
circuit: Quantum circuit to evaluate.
123-
utility_type: Type of utility function ('success_probability', 'ghz', 'one_norm').
124-
target_state: Target state for success probability (defaults to all-zero state).
125-
**kwargs: Additional arguments for specific utility functions.
126-
127-
Returns:
128-
UtilityFunction instance.
129-
"""
130-
if utility_type == "success_probability":
131-
if target_state is None:
132-
target_state = "0" * circuit.num_qubits
133-
return SuccessProbability(target_state)
134-
135-
elif utility_type == "ghz":
136-
return GHZUtility(circuit.num_qubits)
137-
138-
elif utility_type == "one_norm":
139-
if "ideal_distribution" not in kwargs:
140-
raise ValueError("ideal_distribution required for one_norm utility")
141-
return OneNormDistance(kwargs["ideal_distribution"])
142-
143-
else:
144-
raise ValueError(f"Unknown utility type: {utility_type}")
145-
146-
147112
class SuccessProbability(UtilityFunction):
148113
"""Utility function based on success probability of measuring a target state."""
149114

@@ -270,7 +235,6 @@ def get_name(self) -> str:
270235
return self._name
271236

272237

273-
# Factory remains mostly the same but types are updated
274238
class UtilityFactory:
275239
"""Factory class for creating common utility functions."""
276240

@@ -291,7 +255,8 @@ def ghz_state(n_qubits: int) -> UtilityFunction:
291255

292256
@staticmethod
293257
def custom(
294-
function: Callable[[Dict[str, float]], float], name: str = "Custom Utility"
258+
function: Callable[[Dict[str, float]], float],
259+
name: str = "Custom Utility Function",
295260
) -> UtilityFunction:
296261
"""Create custom utility function."""
297262
return CustomUtility(function, name)

tests/test_group_operations.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,23 @@ def test_inverse_uniqueness(self):
181181
len(inverses), 1, f"Element {i} should have exactly one inverse"
182182
)
183183

184+
def test_invert_fallback_path(self):
185+
"""Test invert function when using multiplication search instead of inverse map."""
186+
# Create a custom group without inverse_map to force the fallback path
187+
custom_group = {
188+
"elements": {"e": 0, "a": 1},
189+
"names": {0: "e", 1: "a"},
190+
"multiplication": [
191+
[0, 1], # e * e = e, e * a = a
192+
[1, 0], # a * e = a, a * a = e
193+
],
194+
# Note: no inverse_map provided
195+
}
196+
197+
# Test that it finds inverses correctly using multiplication
198+
self.assertEqual(invert(0, custom_group), 0) # e^-1 = e
199+
self.assertEqual(invert(1, custom_group), 1) # a^-1 = a
200+
184201
def test_invert_custom_group_no_inverse(self):
185202
"""Test that invert raises ValueError for element without inverse in custom group."""
186203
# Create a custom "group" that's not actually a group (element 2 has no inverse)
@@ -199,6 +216,20 @@ def test_invert_custom_group_no_inverse(self):
199216
invert(2, custom_group)
200217
self.assertIn("No inverse found for element 2", str(context.exception))
201218

219+
def test_complete_sequence_already_identity(self):
220+
"""Test complete_sequence_to_identity when sequence already multiplies to Ip."""
221+
# Empty sequence is already identity, should return Ip
222+
self.assertEqual(complete_sequence_to_identity([]), 0)
223+
224+
# Sequence that multiplies to Ip: [Xp, Xp] = Ip
225+
self.assertEqual(complete_sequence_to_identity([2, 2]), 0)
226+
227+
# Another sequence that gives Ip: [Yp, Yp] = Ip
228+
self.assertEqual(complete_sequence_to_identity([4, 4]), 0)
229+
230+
# Sequence that gives Im: [Xp, Xm] = Im, should return Im to get back to Ip
231+
self.assertEqual(complete_sequence_to_identity([2, 3]), 1)
232+
202233

203234
if __name__ == "__main__":
204235
unittest.main()

tests/test_utility_functions.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,42 @@
33
import unittest
44
from qiskit.result import QuasiDistribution, Counts
55
from gadd.utility_functions import (
6+
normalize_counts,
67
UtilityFactory,
8+
SuccessProbability,
9+
OneNormDistance,
10+
GHZUtility,
11+
CustomUtility,
712
)
813

914

1015
class TestUtilityFunctions(unittest.TestCase):
16+
"""
17+
Test the utility function class and helper functions.
18+
"""
19+
20+
def test_normalize_counts_edge_cases(self):
21+
"""Test normalize_counts with edge cases."""
22+
# Test empty counts
23+
with self.assertRaises(ValueError) as context:
24+
normalize_counts({})
25+
self.assertIn("Empty counts", str(context.exception))
26+
27+
# Test zero total counts
28+
with self.assertRaises(ValueError) as context:
29+
normalize_counts({"00": 0, "11": 0})
30+
self.assertIn(
31+
"Total counts/probabilities must be positive", str(context.exception)
32+
)
33+
34+
# Test integer keys with different bit lengths
35+
counts = {0: 100, 15: 200, 7: 300} # 0b0, 0b1111, 0b111
36+
normalized = normalize_counts(counts)
37+
self.assertIn("0000", normalized) # Should pad to 4 bits (max needed)
38+
self.assertIn("1111", normalized)
39+
self.assertIn("0111", normalized)
40+
self.assertAlmostEqual(sum(normalized.values()), 1.0)
41+
1142
def test_success_probability(self):
1243
"""Test success probability utility function."""
1344
utility = UtilityFactory.success_probability("00")
@@ -61,6 +92,97 @@ def custom_func(counts):
6192
self.assertAlmostEqual(utility.compute(counts), 0.75)
6293
self.assertEqual(utility.get_name(), "Test Utility")
6394

95+
def test_verify_state(self):
96+
"""Test verify_state method."""
97+
utility = SuccessProbability("00")
98+
99+
# Test valid state
100+
counts = {"00": 0.5, "11": 0.5}
101+
utility.verify_state("00", counts) # Should not raise
102+
103+
# Test mismatched length
104+
with self.assertRaises(ValueError) as context:
105+
utility.verify_state("000", counts)
106+
self.assertIn("length", str(context.exception))
107+
108+
# Test non-binary string
109+
with self.assertRaises(ValueError) as context:
110+
utility.verify_state("0x", counts)
111+
self.assertIn("binary string", str(context.exception))
112+
113+
# Test empty counts
114+
with self.assertRaises(ValueError) as context:
115+
utility.verify_state("00", {})
116+
self.assertIn("Empty counts", str(context.exception))
117+
118+
def test_success_probability_integer_target(self):
119+
"""Test SuccessProbability with integer target state."""
120+
# Test with integer target
121+
utility = SuccessProbability(3) # Binary: 11
122+
counts = {"00": 0.25, "01": 0.25, "10": 0.25, "11": 0.25}
123+
self.assertAlmostEqual(utility.compute(counts), 0.25)
124+
125+
# Test that it pads correctly
126+
utility2 = SuccessProbability(1) # Binary: 1
127+
counts2 = {"000": 0.5, "001": 0.5}
128+
self.assertAlmostEqual(utility2.compute(counts2), 0.5)
129+
130+
def test_success_probability_invalid_target(self):
131+
"""Test SuccessProbability with invalid target."""
132+
with self.assertRaises(ValueError) as context:
133+
SuccessProbability("0a1")
134+
self.assertIn("binary string", str(context.exception))
135+
136+
def test_one_norm_dimension_mismatch(self):
137+
"""Test OneNormDistance with mismatched dimensions."""
138+
utility = OneNormDistance({"00": 0.5, "11": 0.5})
139+
140+
# Test with different length states
141+
counts = {"000": 0.5, "111": 0.5}
142+
with self.assertRaises(ValueError) as context:
143+
utility.compute(counts)
144+
self.assertIn("dimensions don't match", str(context.exception))
145+
146+
def test_ghz_utility_invalid_qubits(self):
147+
"""Test GHZUtility with invalid qubit number."""
148+
with self.assertRaises(ValueError) as context:
149+
GHZUtility(0)
150+
self.assertIn("positive", str(context.exception))
151+
152+
with self.assertRaises(ValueError) as context:
153+
GHZUtility(-1)
154+
self.assertIn("positive", str(context.exception))
155+
156+
def test_counts_type_variations(self):
157+
"""Test various count input types."""
158+
utility = SuccessProbability("00")
159+
160+
# Test with Counts object (mock it since we don't want to import from qiskit.result)
161+
class MockCounts(dict):
162+
pass
163+
164+
mock_counts = MockCounts({"00": 800, "01": 200})
165+
result = utility.compute(mock_counts)
166+
self.assertAlmostEqual(result, 0.8)
167+
168+
def test_all_utility_names(self):
169+
"""Test get_name methods for all utility types."""
170+
# SuccessProbability
171+
util1 = SuccessProbability("101")
172+
self.assertEqual(util1.get_name(), "Success Probability (|101⟩)")
173+
174+
# OneNormDistance
175+
util2 = OneNormDistance({"00": 1.0})
176+
self.assertEqual(util2.get_name(), "1-Norm Distance")
177+
178+
# GHZUtility
179+
util3 = GHZUtility(5)
180+
self.assertEqual(util3.get_name(), "GHZ State Fidelity")
181+
182+
# CustomUtility with custom name
183+
util4 = CustomUtility(lambda x: 0.5, "My Custom Utility")
184+
self.assertEqual(util4.get_name(), "My Custom Utility")
185+
64186

65187
if __name__ == "__main__":
66188
unittest.main()

0 commit comments

Comments
 (0)