Skip to content

Commit 9c033c4

Browse files
committed
tested version of pytest suite
1 parent 3ae2cbc commit 9c033c4

File tree

3 files changed

+171
-28
lines changed

3 files changed

+171
-28
lines changed

tests/conftest.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,31 @@ def dj_config():
2424
"database.user": os.environ.get("DJ_USER") or dj.config["database.user"],
2525
}
2626
)
27+
os.environ["DATABASE_PREFIX"] = "test_"
2728
return
2829

2930

3031
@pytest.fixture(autouse=True, scope="session")
3132
def pipeline():
32-
import tutorial_pipeline as pipeline
33+
from . import tutorial_pipeline as pipeline
3334

3435
yield {
3536
"lab": pipeline.lab,
3637
"subject": pipeline.subject,
3738
"session": pipeline.session,
3839
"probe": pipeline.probe,
3940
"ephys": pipeline.ephys,
41+
"ephys_report": pipeline.ephys_report,
4042
"get_ephys_root_data_dir": pipeline.get_ephys_root_data_dir,
4143
}
4244

4345
if _tear_down:
44-
pipeline.subject.Subject.delete()
46+
pipeline.ephys_report.schema.drop()
47+
pipeline.ephys.schema.drop()
48+
pipeline.probe.schema.drop()
49+
pipeline.session.schema.drop()
50+
pipeline.subject.schema.drop()
51+
pipeline.lab.schema.drop()
4552

4653

4754
@pytest.fixture(scope="session")
@@ -53,37 +60,46 @@ def insert_upstreams(pipeline):
5360
ephys = pipeline["ephys"]
5461

5562
subject.Subject.insert1(
56-
dict(subject="subject5", subject_birth_date="2023-01-01", sex="U")
63+
dict(subject="subject5", subject_birth_date="2023-01-01", sex="U"),
64+
skip_duplicates=True,
5765
)
5866

5967
session_key = dict(subject="subject5", session_datetime="2023-01-01 00:00:00")
68+
session.Session.insert1(session_key, skip_duplicates=True)
6069
session_dir = "raw/subject5/session1"
6170

62-
session.SessionDirectory.insert1(dict(**session_key, session_dir=session_dir))
63-
probe.Probe.insert1(dict(probe="714000838", probe_type="neuropixels 1.0 - 3B"))
71+
session.SessionDirectory.insert1(
72+
dict(**session_key, session_dir=session_dir), skip_duplicates=True
73+
)
74+
probe.Probe.insert1(
75+
dict(probe="714000838", probe_type="neuropixels 1.0 - 3B"), skip_duplicates=True
76+
)
6477
ephys.ProbeInsertion.insert1(
6578
dict(
66-
session_key,
79+
**session_key,
6780
insertion_number=1,
6881
probe="714000838",
69-
)
82+
),
83+
skip_duplicates=True,
7084
)
71-
yield
7285

73-
if _tear_down:
74-
subject.Subject.delete()
75-
probe.Probe.delete()
86+
return
7687

7788

7889
@pytest.fixture(scope="session")
79-
def populate_ephys_recording(pipeline, insert_upstream):
90+
def populate_ephys_recording(pipeline, insert_upstreams):
8091
ephys = pipeline["ephys"]
8192
ephys.EphysRecording.populate()
8293

83-
yield
94+
return
8495

85-
if _tear_down:
86-
ephys.EphysRecording.delete()
96+
97+
@pytest.fixture(scope="session")
98+
def populate_lfp(pipeline, insert_upstreams):
99+
ephys = pipeline["ephys"]
100+
ephys.LFP.populate()
101+
102+
return
87103

88104

89105
@pytest.fixture(scope="session")
@@ -129,25 +145,20 @@ def insert_clustering_task(pipeline, populate_ephys_recording):
129145
paramset_idx=0,
130146
task_mode="load", # load or trigger
131147
clustering_output_dir="processed/subject5/session1/probe_1/kilosort2-5_1",
132-
)
148+
),
149+
skip_duplicates=True,
133150
)
134151

135-
yield
136-
137-
if _tear_down:
138-
ephys.ClusteringParamSet.delete()
152+
return
139153

140154

141155
@pytest.fixture(scope="session")
142-
def processing(pipeline, populate_ephys_recording):
156+
def processing(pipeline, insert_clustering_task):
143157

144158
ephys = pipeline["ephys"]
159+
ephys.Clustering.populate()
145160
ephys.CuratedClustering.populate()
146-
ephys.LFP.populate()
147161
ephys.WaveformSet.populate()
162+
ephys.QualityMetrics.populate()
148163

149-
yield
150-
151-
if _tear_down:
152-
ephys.CuratedClustering.delete()
153-
ephys.LFP.delete()
164+
return

tests/test_pipeline.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import numpy as np
2+
import pandas as pd
3+
import datetime
4+
from uuid import UUID
5+
6+
7+
def test_generate_pipeline(pipeline):
8+
subject = pipeline["subject"]
9+
session = pipeline["session"]
10+
ephys = pipeline["ephys"]
11+
probe = pipeline["probe"]
12+
13+
# test elements connection from lab, subject to Session
14+
assert subject.Subject.full_table_name in session.Session.parents()
15+
16+
# test elements connection from Session to probe, ephys, ephys_report
17+
assert session.Session.full_table_name in ephys.ProbeInsertion.parents()
18+
assert probe.Probe.full_table_name in ephys.ProbeInsertion.parents()
19+
assert "spike_times" in (ephys.CuratedClustering.Unit.heading.secondary_attributes)
20+
21+
22+
def test_insert_upstreams(pipeline, insert_upstreams):
23+
"""Check number of subjects inserted into the `subject.Subject` table"""
24+
subject = pipeline["subject"]
25+
session = pipeline["session"]
26+
probe = pipeline["probe"]
27+
ephys = pipeline["ephys"]
28+
29+
assert len(subject.Subject()) == 1
30+
assert len(session.Session()) == 1
31+
assert len(probe.Probe()) == 1
32+
assert len(ephys.ProbeInsertion()) == 1
33+
34+
35+
def test_populate_ephys_recording(pipeline, populate_ephys_recording):
36+
ephys = pipeline["ephys"]
37+
38+
assert ephys.EphysRecording.fetch1() == {
39+
"subject": "subject5",
40+
"session_datetime": datetime.datetime(2023, 1, 1, 0, 0),
41+
"insertion_number": 1,
42+
"electrode_config_hash": UUID("8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee"),
43+
"acq_software": "SpikeGLX",
44+
"sampling_rate": 30000.0,
45+
"recording_datetime": datetime.datetime(2018, 7, 3, 20, 32, 28),
46+
"recording_duration": 338.666,
47+
}
48+
assert (
49+
ephys.EphysRecording.EphysFile.fetch1("file_path")
50+
== "raw/subject5/session1/probe_1/npx_g0_t0.imec.ap.meta"
51+
)
52+
53+
54+
def test_populate_lfp(pipeline, populate_lfp):
55+
ephys = pipeline["ephys"]
56+
57+
assert np.mean(ephys.LFP.fetch1("lfp_mean")) == -716.0220556825378
58+
assert len((ephys.LFP.Electrode).fetch("electrode")) == 43
59+
60+
61+
def test_insert_clustering_task(pipeline, insert_clustering_task):
62+
ephys = pipeline["ephys"]
63+
64+
assert ephys.ClusteringParamSet.fetch1("param_set_hash") == UUID(
65+
"de78cee1-526f-319e-b6d5-8a2ba04963d8"
66+
)
67+
68+
assert ephys.ClusteringTask.fetch1() == {
69+
"subject": "subject5",
70+
"session_datetime": datetime.datetime(2023, 1, 1, 0, 0),
71+
"insertion_number": 1,
72+
"paramset_idx": 0,
73+
"clustering_output_dir": "processed/subject5/session1/probe_1/kilosort2-5_1",
74+
"task_mode": "load",
75+
}
76+
77+
78+
def test_processing(pipeline, processing):
79+
80+
ephys = pipeline["ephys"]
81+
82+
# test ephys.CuratedClustering
83+
assert len(ephys.CuratedClustering.Unit & 'cluster_quality_label = "good"') == 176
84+
assert np.sum(ephys.CuratedClustering.Unit.fetch("spike_count")) == 328167
85+
# test ephys.WaveformSet
86+
waveforms = np.vstack(
87+
(ephys.WaveformSet.PeakWaveform).fetch("peak_electrode_waveform")
88+
)
89+
assert waveforms.shape == (227, 82)
90+
91+
# test ephys.QualityMetrics
92+
cluster_df = (ephys.QualityMetrics.Cluster).fetch(format="frame", order_by="unit")
93+
waveform_df = (ephys.QualityMetrics.Waveform).fetch(format="frame", order_by="unit")
94+
test_df = pd.concat([cluster_df, waveform_df], axis=1).reset_index()
95+
test_value = test_df.select_dtypes(include=[np.number]).mean().values
96+
97+
assert np.allclose(
98+
test_value,
99+
np.array(
100+
[
101+
1.00000000e00,
102+
0.00000000e00,
103+
1.13000000e02,
104+
4.26880089e00,
105+
1.24162431e00,
106+
7.17929515e-01,
107+
4.41633793e-01,
108+
3.08736082e-01,
109+
1.24039274e15,
110+
1.66763828e-02,
111+
4.33231948e00,
112+
7.12304747e-01,
113+
1.48995215e-02,
114+
7.73432472e-02,
115+
5.06451613e00,
116+
7.79528634e00,
117+
6.30182452e-01,
118+
1.19562726e02,
119+
7.90175419e-01,
120+
np.nan,
121+
8.78436780e-01,
122+
1.08028193e-01,
123+
-5.19418717e-02,
124+
2.36035242e02,
125+
7.48443665e-02,
126+
2.77550214e-02,
127+
]
128+
),
129+
rtol=1e-03,
130+
atol=1e-03,
131+
equal_nan=True,
132+
)

tests/tutorial_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datajoint as dj
44
from element_animal import subject
55
from element_animal.subject import Subject
6-
from element_array_ephys import probe, ephys_no_curation as ephys
6+
from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report
77
from element_lab import lab
88
from element_lab.lab import Lab, Location, Project, Protocol, Source, User
99
from element_lab.lab import Device as Equipment

0 commit comments

Comments
 (0)