
Commit 3ce8bb5

Initial commit
0 parents  commit 3ce8bb5

File tree

4 files changed: +243 -0 lines changed

README.md

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# szcore-evaluation

Compare szCORE compliant annotations of EEG datasets of people with epilepsy.

The package compares annotations stored as TSV files. The annotations should be organized in a BIDS compliant manner:

```txt
BIDS_DATASET/
├── ...
├── sub-01/
│   ├── ses-01/
│   │   └── eeg/
│   │       ├── sub-01_ses-01_task-szMonitoring_run-00_events.tsv
│   │       ├── ...
│   ├── ...
├── ...
```
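
Each `*_events.tsv` is a tab-separated annotation file. The snippet below is only an illustration: `onset` and `duration` (in seconds) are the BIDS-required columns, while the `eventType` column and its value are assumptions here; the szCORE specification defines the authoritative set of columns.

```txt
onset	duration	eventType
296.0	40.0	sz
```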

The package compares hypothesis annotations to reference annotations from two folders that follow the same structure. It provides a JSON file with the overall results as output:

```json
{
  "sample_results": {
    "sensitivity": 0.08,
    "sensitivity_std": 0.04,
    "precision": 0.01,
    "precision_std": 0.01,
    "f1": 0.02,
    "f1_std": 0.01,
    "fpRate": 9792.41,
    "fpRate_std": 4566.68
  },
  "event_results": {
    "sensitivity": 1.0,
    "sensitivity_std": 0.0,
    "precision": 0.08,
    "precision_std": 0.03,
    "f1": 0.16,
    "f1_std": 0.04,
    "fpRate": 280.55,
    "fpRate_std": 0.03
  }
}
```

The library provides a simple interface:

```python
def evaluate_dataset(
    reference: Path, hypothesis: Path, outFile: Path, avg_per_subject=True
) -> dict:
    """
    Compares two sets of seizure annotations across a full dataset.

    Parameters:
        reference (Path): The path to the folder containing the reference TSV files.
        hypothesis (Path): The path to the folder containing the hypothesis TSV files.
        outFile (Path): The path to the output JSON file where the results are saved.
        avg_per_subject (bool): Whether to compute average scores per subject or
            average across the full dataset.

    Returns:
        dict: The evaluation result. The dictionary contains the following keys:
            {'sample_results': {'sensitivity', 'precision', 'f1', 'fpRate',
                                'sensitivity_std', 'precision_std', 'f1_std', 'fpRate_std'},
             'event_results': {...}}
    """
```
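
A minimal usage sketch; the folder and output paths below are placeholders:

```python
from pathlib import Path

from szcore_evaluation.evaluate import evaluate_dataset

# Hypothetical paths: point these at BIDS-organized reference and hypothesis folders.
results = evaluate_dataset(
    reference=Path("data/reference"),
    hypothesis=Path("data/hypothesis"),
    outFile=Path("results.json"),
    avg_per_subject=True,
)
print(results["event_results"]["f1"])
```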

pyproject.toml

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
[project]
name = "szcore-evaluation"
version = "0.1.0"
description = "Compare szCORE compliant annotations of EEG datasets of people with epilepsy."
authors = [
    { name = "Jonathan Dan", email = "jonathan.dan@epfl.ch" }
]
dependencies = [
    "epilepsy2bids>=0.7",
    "numpy>=1.26",
    "timescore>=0.0.5",
]
readme = "README.md"
requires-python = ">= 3.12"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/szcore_evaluation"]

src/szcore_evaluation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from szcore-evaluation!"

src/szcore_evaluation/evaluate.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import json
from pathlib import Path

from epilepsy2bids.annotations import Annotations
import numpy as np
from timescoring import scoring
from timescoring.annotations import Annotation


class Result(scoring._Scoring):
    """Helper class built on top of scoring._Scoring that implements the sum
    operator between two scoring objects. The sum corresponds to the
    concatenation of both objects.

    Args:
        score (scoring._Scoring): initialized as None (all zeros) or from a
            scoring._Scoring object.
    """

    def __init__(self, score: scoring._Scoring = None):
        if score is None:
            self.fs = 0
            self.duration = 0
            self.numSamples = 0
            self.tp = 0
            self.fp = 0
            self.refTrue = 0
        else:
            self.fs = score.ref.fs
            self.duration = len(score.ref.mask) / score.ref.fs
            self.numSamples = score.numSamples
            self.tp = score.tp
            self.fp = score.fp
            self.refTrue = score.refTrue

    def __add__(self, other_result: scoring._Scoring):
        new_result = Result()
        new_result.fs = other_result.fs
        new_result.duration = self.duration + other_result.duration
        new_result.numSamples = self.numSamples + other_result.numSamples
        new_result.tp = self.tp + other_result.tp
        new_result.fp = self.fp + other_result.fp
        new_result.refTrue = self.refTrue + other_result.refTrue

        return new_result

    def __iadd__(self, other_result: scoring._Scoring):
        self.fs = other_result.fs
        self.duration += other_result.duration
        self.numSamples += other_result.numSamples
        self.tp += other_result.tp
        self.fp += other_result.fp
        self.refTrue += other_result.refTrue

        return self


def evaluate_dataset(
    reference: Path, hypothesis: Path, outFile: Path, avg_per_subject=True
) -> dict:
    """
    Compares two sets of seizure annotations across a full dataset.

    Parameters:
        reference (Path): The path to the folder containing the reference TSV files.
        hypothesis (Path): The path to the folder containing the hypothesis TSV files.
        outFile (Path): The path to the output JSON file where the results are saved.
        avg_per_subject (bool): Whether to compute average scores per subject or
            average across the full dataset.

    Returns:
        dict: The evaluation result. The dictionary contains the following keys:
            {'sample_results': {'sensitivity', 'precision', 'f1', 'fpRate',
                                'sensitivity_std', 'precision_std', 'f1_std', 'fpRate_std'},
             'event_results': {...}}
    """

    FS = 1

    sample_results = dict()
    event_results = dict()
    for subject in Path(reference).glob("sub-*"):
        sample_results[subject.name] = Result()
        event_results[subject.name] = Result()

        for ref_tsv in subject.glob("**/*.tsv"):
            # Load reference
            ref = Annotations.loadTsv(ref_tsv)
            ref = Annotation(ref.getMask(FS), FS)

            # Load hypothesis
            hyp_tsv = Path(hypothesis) / ref_tsv.relative_to(reference)
            if hyp_tsv.exists():
                hyp = Annotations.loadTsv(hyp_tsv)
                hyp = Annotation(hyp.getMask(FS), FS)
            else:
                hyp = Annotation(np.zeros_like(ref.mask), ref.fs)

            # Compute evaluation
            sample_score = scoring.SampleScoring(ref, hyp)
            event_score = scoring.EventScoring(ref, hyp)

            # Store results
            sample_results[subject.name] += Result(sample_score)
            event_results[subject.name] += Result(event_score)

        # Compute scores
        sample_results[subject.name].computeScores()
        event_results[subject.name].computeScores()

    aggregated_sample_results = dict()
    aggregated_event_results = dict()
    if avg_per_subject:
        for result_builder, aggregated_result in zip(
            (sample_results, event_results),
            (aggregated_sample_results, aggregated_event_results),
        ):
            for metric in ["sensitivity", "precision", "f1", "fpRate"]:
                aggregated_result[metric] = np.mean(
                    [getattr(x, metric) for x in result_builder.values()]
                )
                aggregated_result[f"{metric}_std"] = np.std(
                    [getattr(x, metric) for x in result_builder.values()]
                )
    else:
        for result_builder, aggregated_result in zip(
            (sample_results, event_results),
            (aggregated_sample_results, aggregated_event_results),
        ):
            result_builder["cumulated"] = Result()
            for result in result_builder.values():
                result_builder["cumulated"] += result
            result_builder["cumulated"].computeScores()
            for metric in ["sensitivity", "precision", "f1", "fpRate"]:
                aggregated_result[metric] = getattr(result_builder["cumulated"], metric)

    output = {
        "sample_results": aggregated_sample_results,
        "event_results": aggregated_event_results,
    }
    with open(outFile, "w") as file:
        json.dump(output, file, indent=2, sort_keys=False)

    return output
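
The `Result` helper above accumulates raw counts (`tp`, `fp`, `refTrue`, `duration`) across recordings and only derives metrics once `computeScores()` is called. A minimal sketch of that aggregation on two short synthetic recordings (the masks below are made up for illustration):

```python
import numpy as np
from timescoring import scoring
from timescoring.annotations import Annotation

from szcore_evaluation.evaluate import Result

fs = 1  # 1 Hz masks, matching the FS used in evaluate_dataset

# Two short synthetic recordings: reference and hypothesis masks.
ref_a = Annotation(np.array([0, 1, 1, 0, 0, 0], dtype=bool), fs)
hyp_a = Annotation(np.array([0, 1, 0, 0, 1, 0], dtype=bool), fs)
ref_b = Annotation(np.array([0, 0, 0, 1, 1, 1], dtype=bool), fs)
hyp_b = Annotation(np.array([0, 0, 0, 1, 1, 0], dtype=bool), fs)

# Summing Result objects concatenates their counts.
total = Result()
total += Result(scoring.SampleScoring(ref_a, hyp_a))
total += Result(scoring.SampleScoring(ref_b, hyp_b))
total.computeScores()  # derive sensitivity, precision, f1, fpRate from the sums

print(total.sensitivity, total.precision, total.f1, total.fpRate)
```

Summing counts before computing scores is what lets `evaluate_dataset` report either per-subject averages or a single cumulated score over the whole dataset.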
