Skip to content

Commit a73c7c0

Browse files
Standardize all feature functions to return dictionaries
1 parent 1c40cd5 commit a73c7c0

File tree

4 files changed

+166
-48
lines changed

4 files changed

+166
-48
lines changed

pyeyesweb/low_level/equilibrium.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,28 @@ class Equilibrium:
3131
>>> left = np.array([0, 0, 0])
3232
>>> right = np.array([400, 0, 0])
3333
>>> barycenter = np.array([200, 50, 0])
34-
>>> value, angle = eq(left, right, barycenter)
35-
>>> round(value, 2)
34+
>>> result = eq(left, right, barycenter)
35+
>>> round(result['value'], 2)
3636
0.91
37-
>>> round(angle, 1)
37+
>>> round(result['angle'], 1)
3838
0.0
3939
4040
# Using 2D coordinates (z is optional)
4141
>>> left_2d = np.array([0, 0])
4242
>>> right_2d = np.array([400, 0])
4343
>>> barycenter_2d = np.array([200, 50])
44-
>>> value_2d, angle_2d = eq(left_2d, right_2d, barycenter_2d)
45-
>>> round(value_2d, 2)
44+
>>> result_2d = eq(left_2d, right_2d, barycenter_2d)
45+
>>> round(result_2d['value'], 2)
4646
0.91
47-
>>> round(angle_2d, 1)
47+
>>> round(result_2d['angle'], 1)
4848
0.0
4949
"""
5050

5151
def __init__(self, margin_mm=100, y_weight=0.5):
5252
self.margin = margin_mm
5353
self.y_weight = y_weight
5454

55-
def __call__(self, left_foot: np.ndarray, right_foot: np.ndarray, barycenter: np.ndarray) -> tuple[float, float]:
55+
def __call__(self, left_foot: np.ndarray, right_foot: np.ndarray, barycenter: np.ndarray) -> dict:
5656
"""
5757
Evaluate the equilibrium value and ellipse angle.
5858
@@ -70,13 +70,13 @@ def __call__(self, left_foot: np.ndarray, right_foot: np.ndarray, barycenter: np
7070
7171
Returns
7272
-------
73-
value : float
74-
Equilibrium value in [0, 1].
75-
- 1 means the barycenter is perfectly at the ellipse center.
76-
- 0 means the barycenter is outside the ellipse.
77-
angle : float
78-
Orientation of the ellipse in degrees, measured counter-clockwise
79-
from the X-axis (line connecting left and right foot).
73+
dict
74+
Dictionary containing:
75+
- 'value': Equilibrium value in [0, 1].
76+
1 means the barycenter is perfectly at the ellipse center.
77+
0 means the barycenter is outside the ellipse.
78+
- 'angle': Orientation of the ellipse in degrees, measured counter-clockwise
79+
from the X-axis (line connecting left and right foot).
8080
8181
Notes
8282
-----
@@ -157,4 +157,4 @@ def __call__(self, left_foot: np.ndarray, right_foot: np.ndarray, barycenter: np
157157
else:
158158
value = 0.0
159159

160-
return max(0.0, value), np.degrees(angle)
160+
return {"value": max(0.0, value), "angle": np.degrees(angle)}

pyeyesweb/mid_level/smoothness.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class Smoothness:
5353
>>> for value in movement_data:
5454
... window.append([value])
5555
>>>
56-
>>> sparc, jerk = smooth(window)
57-
>>> print(f"SPARC: {sparc:.3f}, Jerk RMS: {jerk:.3f}")
56+
>>> result = smooth(window)
57+
>>> print(f"SPARC: {result['sparc']:.3f}, Jerk RMS: {result['jerk_rms']:.3f}")
5858
5959
Notes
6060
-----
@@ -106,20 +106,15 @@ def __call__(self, sliding_window: SlidingWindow):
106106
107107
Returns
108108
-------
109-
sparc : float or None
110-
Spectral Arc Length (more negative = smoother).
111-
Returns None if insufficient data.
112-
jerk : float or None
113-
RMS of jerk (third derivative).
114-
Returns None if insufficient data.
115-
116-
Returns
117-
-------
118-
tuple of (float, float)
119-
(SPARC value, Jerk RMS value) or (NaN, NaN) if insufficient data.
109+
dict
110+
Dictionary containing:
111+
- 'sparc': Spectral Arc Length (more negative = smoother).
112+
Returns NaN if insufficient data.
113+
- 'jerk_rms': RMS of jerk (third derivative).
114+
Returns NaN if insufficient data.
120115
"""
121116
if len(sliding_window) < 5:
122-
return float("nan"), float("nan")
117+
return {"sparc": float("nan"), "jerk_rms": float("nan")}
123118

124119
signal, _ = sliding_window.to_array()
125120

@@ -133,4 +128,4 @@ def __call__(self, sliding_window: SlidingWindow):
133128
sparc = compute_sparc(normalized, self.rate_hz)
134129
jerk = compute_jerk_rms(filtered, self.rate_hz)
135130

136-
return sparc, jerk
131+
return {"sparc": sparc, "jerk_rms": jerk}

pyeyesweb/sync.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
sys.path.append(os.getcwd())
3131

3232
from collections import deque
33+
import numpy as np
3334

3435
from pyeyesweb.data_models.sliding_window import SlidingWindow
3536
from pyeyesweb.utils.signal_processing import bandpass_filter, compute_hilbert_phases
@@ -110,8 +111,8 @@ class Synchronization:
110111
... window.append([signal1[i], signal2[i]])
111112
>>>
112113
>>> # Compute synchronization
113-
>>> plv, status = sync(window)
114-
>>> print(f"PLV: {plv:.3f}, Status: {status}")
114+
>>> result = sync(window)
115+
>>> print(f"PLV: {result['plv']:.3f}, Status: {result['phase_status']}")
115116
116117
Notes
117118
-----
@@ -176,13 +177,13 @@ def compute_synchronization(self, signals: SlidingWindow):
176177
177178
Returns
178179
-------
179-
plv : float or None
180-
Phase Locking Value between 0 (no sync) and 1 (perfect sync).
181-
Returns None if the window is not full.
182-
phase_status : str or None
183-
If output_phase is True, returns "IN PHASE" when PLV > phase_threshold,
184-
"OUT OF PHASE" otherwise. Returns None if output_phase is False or
185-
if the window is not full.
180+
dict
181+
Dictionary containing synchronization metrics:
182+
- 'plv': Phase Locking Value between 0 (no sync) and 1 (perfect sync).
183+
Returns NaN if the window is not full.
184+
- 'phase_status': If output_phase is True, returns "IN PHASE" when PLV > phase_threshold,
185+
"OUT OF PHASE" otherwise. Returns None if output_phase is False or
186+
if the window is not full.
186187
187188
Notes
188189
-----
@@ -200,7 +201,7 @@ def compute_synchronization(self, signals: SlidingWindow):
200201
raise ValueError(f"Synchronization requires exactly 2 signal channels, got {signals._n_columns}")
201202

202203
if not signals.is_full():
203-
return float("nan"), None
204+
return {"plv": float("nan"), "phase_status": None}
204205

205206
sig, _ = signals.to_array()
206207

@@ -222,7 +223,7 @@ def compute_synchronization(self, signals: SlidingWindow):
222223
# Determine phase synchronization status based on threshold
223224
phase_status = "IN PHASE" if plv > self.phase_threshold else "OUT OF PHASE"
224225

225-
return plv, phase_status
226+
return {"plv": plv, "phase_status": phase_status}
226227

227228
def __call__(self, sliding_window: SlidingWindow):
228229
"""Compute and optionally display synchronization metrics.
@@ -236,21 +237,23 @@ def __call__(self, sliding_window: SlidingWindow):
236237
237238
Returns
238239
-------
239-
plv : float or None
240-
Phase Locking Value (0-1) or None if insufficient data.
241-
phase_status : str or None
242-
Phase status ("IN PHASE"/"OUT OF PHASE") or None.
240+
dict
241+
Dictionary containing synchronization metrics:
242+
- 'plv': Phase Locking Value (0-1) or NaN if insufficient data.
243+
- 'phase_status': Phase status ("IN PHASE"/"OUT OF PHASE") or None.
243244
244245
Output
245246
------------
246247
Prints synchronization metrics to stdout if PLV is computed successfully.
247248
Format depends on output_phase setting.
248249
"""
249-
plv, phase_status = self.compute_synchronization(sliding_window)
250+
result = self.compute_synchronization(sliding_window)
251+
plv = result["plv"]
252+
phase_status = result["phase_status"]
250253

251-
if plv is not None:
254+
if not np.isnan(plv):
252255
if self.output_phase:
253256
print(f"Synchronization Index: {plv:.3f}, Phase Status: {phase_status}")
254257
else:
255258
print(f"Synchronization Index: {plv:.3f}")
256-
return plv, phase_status
259+
return result

test_dict_output.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#!/usr/bin/env python3
2+
"""Test script to verify all features return dictionaries."""
3+
4+
import numpy as np
5+
from pyeyesweb.sync import Synchronization
6+
from pyeyesweb.low_level.equilibrium import Equilibrium
7+
from pyeyesweb.mid_level.smoothness import Smoothness
8+
from pyeyesweb.data_models.sliding_window import SlidingWindow
9+
10+
def test_synchronization():
11+
"""Test Synchronization class returns dictionary."""
12+
print("Testing Synchronization...")
13+
14+
# Create synchronization analyzer
15+
sync = Synchronization(sensitivity=50, output_phase=True)
16+
17+
# Create sliding window with two signals
18+
window = SlidingWindow(max_length=100, n_columns=2)
19+
20+
# Generate sample data
21+
t = np.linspace(0, 2 * np.pi, 100)
22+
signal1 = np.sin(t)
23+
signal2 = np.sin(t + np.pi/4) # Phase shifted signal
24+
25+
# Fill window
26+
for i in range(100):
27+
window.append([signal1[i], signal2[i]])
28+
29+
# Compute synchronization
30+
result = sync(window)
31+
32+
# Check that result is a dictionary
33+
assert isinstance(result, dict), f"Expected dict, got {type(result)}"
34+
assert 'plv' in result, "Missing 'plv' key in result"
35+
assert 'phase_status' in result, "Missing 'phase_status' key in result"
36+
37+
print(f" ✓ Synchronization returns dict with keys: {list(result.keys())}")
38+
print(f" PLV: {result['plv']:.3f}")
39+
print(f" Phase Status: {result['phase_status']}")
40+
41+
def test_equilibrium():
42+
"""Test Equilibrium class returns dictionary."""
43+
print("\nTesting Equilibrium...")
44+
45+
# Create equilibrium analyzer
46+
eq = Equilibrium(margin_mm=100, y_weight=0.5)
47+
48+
# Define foot positions and barycenter
49+
left_foot = np.array([0, 0, 0])
50+
right_foot = np.array([400, 0, 0])
51+
barycenter = np.array([200, 50, 0])
52+
53+
# Compute equilibrium
54+
result = eq(left_foot, right_foot, barycenter)
55+
56+
# Check that result is a dictionary
57+
assert isinstance(result, dict), f"Expected dict, got {type(result)}"
58+
assert 'value' in result, "Missing 'value' key in result"
59+
assert 'angle' in result, "Missing 'angle' key in result"
60+
61+
print(f" ✓ Equilibrium returns dict with keys: {list(result.keys())}")
62+
print(f" Value: {result['value']:.3f}")
63+
print(f" Angle: {result['angle']:.1f}°")
64+
65+
def test_smoothness():
66+
"""Test Smoothness class returns dictionary."""
67+
print("\nTesting Smoothness...")
68+
69+
# Create smoothness analyzer
70+
smooth = Smoothness(rate_hz=100.0, use_filter=True)
71+
72+
# Create sliding window
73+
window = SlidingWindow(max_length=200, n_columns=1)
74+
75+
# Generate sample movement data (smooth sinusoidal motion)
76+
t = np.linspace(0, 2, 200)
77+
movement = np.sin(2 * np.pi * t)
78+
79+
# Fill window
80+
for value in movement:
81+
window.append([value])
82+
83+
# Compute smoothness
84+
result = smooth(window)
85+
86+
# Check that result is a dictionary
87+
assert isinstance(result, dict), f"Expected dict, got {type(result)}"
88+
assert 'sparc' in result, "Missing 'sparc' key in result"
89+
assert 'jerk_rms' in result, "Missing 'jerk_rms' key in result"
90+
91+
print(f" ✓ Smoothness returns dict with keys: {list(result.keys())}")
92+
print(f" SPARC: {result['sparc']:.3f}")
93+
print(f" Jerk RMS: {result['jerk_rms']:.3f}")
94+
95+
def main():
96+
"""Run all tests."""
97+
print("=" * 50)
98+
print("Testing Feature Dictionary Output Standardization")
99+
print("=" * 50)
100+
101+
try:
102+
test_synchronization()
103+
test_equilibrium()
104+
test_smoothness()
105+
106+
print("\n" + "=" * 50)
107+
print("✅ All tests passed! All features return dictionaries.")
108+
print("=" * 50)
109+
110+
except AssertionError as e:
111+
print(f"\n❌ Test failed: {e}")
112+
return 1
113+
except Exception as e:
114+
print(f"\n❌ Unexpected error: {e}")
115+
return 1
116+
117+
return 0
118+
119+
if __name__ == "__main__":
120+
exit(main())

0 commit comments

Comments
 (0)