Skip to content

Commit d388ec9

Browse files
qian-chupre-commit-ci[bot]drammocklarsoner
authored andcommitted
Fix realign_raw and test_realign (mne-tools#11950)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy <dan@mccloy.info> Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 2bce0b5 commit d388ec9

File tree

4 files changed

+132
-40
lines changed

4 files changed

+132
-40
lines changed

doc/changes/devel.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Enhancements
3636

3737
Bugs
3838
~~~~
39+
- Fix bugs with :func:`mne.preprocessing.realign_raw` where the start of ``other`` was incorrectly cropped; and onsets and durations in ``other.annotations`` were left unsynced with the resampled data (:gh:`11950` by :newcontrib:`Qian Chu`)
3940
- Fix bug where ``encoding`` argument was ignored when reading annotations from an EDF file (:gh:`11958` by :newcontrib:`Andrew Gilbert`)
4041
- Fix bugs with saving splits for :class:`~mne.Epochs` (:gh:`11876` by `Dmitrii Altukhov`_)
4142
- Fix bug with multi-plot 3D rendering where only one plot was updated (:gh:`11896` by `Eric Larson`_)

doc/changes/names.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@
418418

419419
.. _Proloy Das: https://github.com/proloyd
420420

421+
.. _Qian Chu: https://github.com/qian-chu
422+
421423
.. _Qianliang Li: https://www.dtu.dk/english/service/phonebook/person?id=126774
422424

423425
.. _Quentin Barthélemy: https://github.com/qbarthelemy

mne/preprocessing/realign.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Authors: Eric Larson <larson.eric.d@gmail.com>
2-
2+
# Qian Chu <qianchu99@gmail.com>
3+
#
34
# License: BSD-3-Clause
45

56
import numpy as np
@@ -42,7 +43,8 @@ def realign_raw(raw, other, t_raw, t_other, verbose=None):
4243
2. Crop the start of ``raw`` or ``other``, depending on which started
4344
recording first.
4445
3. Resample ``other`` to match ``raw`` based on the clock drift.
45-
4. Crop the end of ``raw`` or ``other``, depending on which stopped
46+
4. Realign the onsets and durations in ``other.annotations``.
47+
5. Crop the end of ``raw`` or ``other``, depending on which stopped
4648
recording first (and the clock drift rate).
4749
4850
This function is primarily designed to work on recordings made at the same
@@ -85,25 +87,41 @@ def realign_raw(raw, other, t_raw, t_other, verbose=None):
8587
f"{raw.times[-1] * dr_ms_s:0.1f} ms)"
8688
)
8789

88-
# 2. Crop start of recordings to match using the zero-order term
89-
msg = f"Cropping {zero_ord:0.3f} s from the start of "
90+
# 2. Crop start of recordings to match
9091
if zero_ord > 0: # need to crop start of raw to match other
91-
logger.info(msg + "raw")
92+
logger.info(f"Cropping {zero_ord:0.3f} s from the start of raw")
9293
raw.crop(zero_ord, None)
9394
t_raw -= zero_ord
9495
else: # need to crop start of other to match raw
95-
logger.info(msg + "other")
96-
other.crop(-zero_ord, None)
97-
t_other += zero_ord
96+
t_crop = zero_ord / first_ord
97+
logger.info(f"Cropping {t_crop:0.3f} s from the start of other")
98+
other.crop(-t_crop, None)
99+
t_other += t_crop
98100

99101
# 3. Resample data using the first-order term
102+
nan_ch_names = [
103+
ch for ch in other.info["ch_names"] if np.isnan(other.get_data(picks=ch)).any()
104+
]
105+
if len(nan_ch_names) > 0: # Issue warning if any channel in other has nan values
106+
warn(
107+
f"Channel(s) {', '.join(nan_ch_names)} in `other` contain NaN values. "
108+
"Resampling these channels will result in the whole channel being NaN. "
109+
"(If realigning eye-tracking data, consider using interpolate_blinks and "
110+
"passing interpolate_gaze=True)"
111+
)
100112
logger.info("Resampling other")
101113
sfreq_new = raw.info["sfreq"] * first_ord
102114
other.load_data().resample(sfreq_new, verbose=True)
103115
with other.info._unlock():
104116
other.info["sfreq"] = raw.info["sfreq"]
105117

106-
# 4. Crop the end of one of the recordings if necessary
118+
# 4. Realign the onsets and durations in other.annotations
119+
# Must happen before end cropping to avoid losing annotations
120+
logger.info("Correcting annotations in other")
121+
other.annotations.onset *= first_ord
122+
other.annotations.duration *= first_ord
123+
124+
# 5. Crop the end of one of the recordings if necessary
107125
delta = raw.times[-1] - other.times[-1]
108126
msg = f"Cropping {abs(delta):0.3f} s from the end of "
109127
if delta > 0:
Lines changed: 102 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# Author: Mark Wronkiewicz <wronk@uw.edu>
1+
# Authors: Mark Wronkiewicz <wronk@uw.edu>
2+
# Qian Chu <qianchu99@gmail.com>
23
#
34
# License: BSD-3-Clause
45

@@ -7,12 +8,12 @@
78
from scipy.interpolate import interp1d
89
import pytest
910

10-
from mne import create_info, find_events, Epochs
11+
from mne import create_info, find_events, Epochs, Annotations
1112
from mne.io import RawArray
1213
from mne.preprocessing import realign_raw
1314

1415

15-
@pytest.mark.parametrize("ratio_other", (1.0, 0.999, 1.001)) # drifts
16+
@pytest.mark.parametrize("ratio_other", (0.9, 0.999, 1, 1.001, 1.1)) # drifts
1617
@pytest.mark.parametrize("start_raw, start_other", [(0, 0), (0, 3), (3, 0)])
1718
@pytest.mark.parametrize("stop_raw, stop_other", [(0, 0), (0, 3), (3, 0)])
1819
def test_realign(ratio_other, start_raw, start_other, stop_raw, stop_other):
@@ -22,6 +23,8 @@ def test_realign(ratio_other, start_raw, start_other, stop_raw, stop_other):
2223
duration = 50
2324
stop_raw = duration - stop_raw
2425
stop_other = duration - stop_other
26+
signal_len = 0.2
27+
box_len = 0.5
2528
signal = np.zeros(int(round((duration + 1) * sfreq)))
2629
orig_events = np.round(
2730
np.arange(max(start_raw, start_other) + 2, min(stop_raw, stop_other) - 2)
@@ -30,8 +33,10 @@ def test_realign(ratio_other, start_raw, start_other, stop_raw, stop_other):
3033
signal[orig_events] = 1.0
3134
n_events = len(orig_events)
3235
times = np.arange(len(signal)) / sfreq
33-
stim = np.convolve(signal, np.ones(int(round(0.02 * sfreq))))[: len(times)]
34-
signal = np.convolve(signal, np.hanning(int(round(0.2 * sfreq))))[: len(times)]
36+
stim = np.convolve(signal, np.ones(int(round(box_len * sfreq))))[: len(times)]
37+
signal = np.convolve(signal, np.hanning(int(round(signal_len * sfreq))))[
38+
: len(times)
39+
]
3540

3641
# construct our sampled versions of these signals (linear interp is fine)
3742
sfreq_raw = sfreq
@@ -45,45 +50,84 @@ def test_realign(ratio_other, start_raw, start_other, stop_raw, stop_other):
4550
data_raw = np.array(
4651
[
4752
interp1d(times, d, kind)(raw_times)
48-
for d, kind in ((signal, "linear"), (stim, "nearest"))
53+
for d, kind in (
54+
(stim, "nearest"),
55+
(signal, "linear"),
56+
)
4957
]
5058
)
5159
data_other = np.array(
5260
[
5361
interp1d(times, d, kind)(other_times)
54-
for d, kind in ((signal, "linear"), (stim, "nearest"))
62+
for d, kind in (
63+
(stim, "nearest"),
64+
(signal, "linear"),
65+
)
5566
]
5667
)
57-
info_raw = create_info(["raw_data", "raw_stim"], sfreq, ["eeg", "stim"])
58-
info_other = create_info(["other_data", "other_stim"], sfreq, ["eeg", "stim"])
59-
raw = RawArray(data_raw, info_raw, first_samp=111)
68+
info_raw = create_info(["raw_stim", "raw_signal"], sfreq, ["stim", "eeg"])
69+
info_other = create_info(["other_stim", "other_signal"], sfreq, ["stim", "eeg"])
70+
raw = RawArray(data_raw, info_raw, first_samp=111) # first_samp shouldn't matter
6071
other = RawArray(data_other, info_other, first_samp=222)
72+
raw.set_meas_date((0, 0)) # meas_date shouldn't matter
73+
other.set_meas_date((100, 0))
6174

62-
# naive processing
63-
evoked_raw, events_raw, _, events_other = _assert_similarity(raw, other, n_events)
64-
if start_raw == start_other: # can just naively crop
65-
a, b = data_raw[0], data_other[0]
66-
n = min(len(a), len(b))
67-
corr = np.corrcoef(a[:n], b[:n])[0, 1]
68-
min_, max_ = (0.99999, 1.0) if sfreq_raw == sfreq_other else (0.8, 0.9)
69-
assert min_ <= corr <= max_
75+
# find events and do basic checks
76+
evoked_raw, events_raw, _, events_other = _assert_similarity(
77+
raw, other, n_events, ratio_other
78+
)
79+
80+
# construct annotations
81+
onsets_raw = (events_raw[:, 0] - raw.first_samp) / raw.info["sfreq"]
82+
dur_raw = [box_len] * len(onsets_raw)
83+
desc_raw = ["raw_box"] * len(onsets_raw)
84+
annot_raw = Annotations(onsets_raw, dur_raw, desc_raw)
85+
raw.set_annotations(annot_raw)
86+
87+
onsets_other = (events_other[:, 0] - other.first_samp) / other.info["sfreq"]
88+
dur_other = [box_len * ratio_other] * len(onsets_other)
89+
desc_other = ["other_box"] * len(onsets_other)
90+
annot_other = Annotations(onsets_other, dur_other, desc_other)
91+
other.set_annotations(annot_other)
92+
93+
# onsets/offsets correspond to 0/1 transition in boxcar signals
94+
_assert_boxcar_annot_similarity(raw, other)
7095

7196
# realign
72-
t_raw = (events_raw[:, 0] - raw.first_samp) / other.info["sfreq"]
97+
t_raw = (events_raw[:, 0] - raw.first_samp) / raw.info["sfreq"]
7398
t_other = (events_other[:, 0] - other.first_samp) / other.info["sfreq"]
7499
assert duration - 10 <= len(events_raw) < duration
75100
raw_orig, other_orig = raw.copy(), other.copy()
76101
realign_raw(raw, other, t_raw, t_other)
77102

78-
# old events should still work for raw and produce the same result
79-
evoked_raw_2, _, _, _ = _assert_similarity(
80-
raw, other, n_events, events_raw=events_raw
103+
# old events should still work for raw and produce the same evoked data
104+
evoked_raw_2, events_raw, _, events_other = _assert_similarity(
105+
raw, other, n_events, ratio_other, events_raw=events_raw
81106
)
82107
assert_allclose(evoked_raw.data, evoked_raw_2.data)
83108
assert_allclose(raw.times, other.times)
109+
84110
# raw data now aligned
85-
corr = np.corrcoef(raw.get_data([0])[0], other.get_data([0])[0])[0, 1]
86-
assert 0.99 < corr <= 1.0
111+
corr = np.corrcoef(raw.get_data("data"), other.get_data("data"))
112+
assert 0.99 < corr[0, 1] <= 1.0
113+
114+
# onsets derived from stim and annotations are the same
115+
atol = 2 / sfreq
116+
assert_allclose(
117+
raw.annotations.onset, events_raw[:, 0] / raw.info["sfreq"], atol=atol
118+
)
119+
assert_allclose(
120+
other.annotations.onset, events_other[:, 0] / other.info["sfreq"], atol=atol
121+
)
122+
123+
# onsets/offsets still correspond to 0/1 transition in boxcar signals
124+
_assert_boxcar_annot_similarity(raw, other)
125+
126+
# onsets and durations now aligned
127+
onsets_raw, dur_raw, onsets_other, dur_other = _annot_to_onset_dur(raw, other)
128+
assert len(onsets_raw) == len(onsets_other) == len(events_raw)
129+
assert_allclose(onsets_raw, onsets_other, atol=atol)
130+
assert_allclose(dur_raw, dur_other, atol=atol)
87131

88132
# Degenerate conditions -- only test in one run
89133
test_degenerate = (
@@ -103,17 +147,44 @@ def test_realign(ratio_other, start_raw, start_other, stop_raw, stop_other):
103147
realign_raw(raw_orig, other_orig, raw_times + rand_times * 1000, other_times)
104148

105149

106-
def _assert_similarity(raw, other, n_events, events_raw=None):
150+
def _assert_similarity(raw, other, n_events, ratio_other, events_raw=None):
107151
if events_raw is None:
108-
events_raw = find_events(raw)
109-
events_other = find_events(other)
110-
assert len(events_raw) == n_events
111-
assert len(events_other) == n_events
152+
events_raw = find_events(raw, output="onset")
153+
events_other = find_events(other, output="onset")
154+
assert len(events_raw) == len(events_other) == n_events
112155
kwargs = dict(baseline=None, tmin=0, tmax=0.2)
113156
evoked_raw = Epochs(raw, events_raw, **kwargs).average()
114157
evoked_other = Epochs(other, events_other, **kwargs).average()
115158
assert evoked_raw.nave == evoked_other.nave == len(events_raw)
116159
assert len(evoked_raw.data) == len(evoked_other.data) == 1 # just EEG
117-
corr = np.corrcoef(evoked_raw.data[0], evoked_other.data[0])[0, 1]
118-
assert 0.9 <= corr <= 1.0
160+
if 0.99 <= ratio_other <= 1.01: # when drift is not too large
161+
corr = np.corrcoef(evoked_raw.data[0], evoked_other.data[0])[0, 1]
162+
assert 0.9 <= corr <= 1.0
119163
return evoked_raw, events_raw, evoked_other, events_other
164+
165+
166+
def _assert_boxcar_annot_similarity(raw, other):
167+
onsets_raw, dur_raw, onsets_other, dur_other = _annot_to_onset_dur(raw, other)
168+
169+
n_events = len(onsets_raw)
170+
onsets_samp_raw = raw.time_as_index(onsets_raw)
171+
offsets_samp_raw = raw.time_as_index(onsets_raw + dur_raw)
172+
assert_allclose(raw.get_data("stim")[0, onsets_samp_raw - 2], [0] * n_events)
173+
assert_allclose(raw.get_data("stim")[0, onsets_samp_raw + 2], [1] * n_events)
174+
assert_allclose(raw.get_data("stim")[0, offsets_samp_raw - 2], [1] * n_events)
175+
assert_allclose(raw.get_data("stim")[0, offsets_samp_raw + 2], [0] * n_events)
176+
onsets_samp_other = other.time_as_index(onsets_other)
177+
offsets_samp_other = other.time_as_index(onsets_other + dur_other)
178+
assert_allclose(other.get_data("stim")[0, onsets_samp_other - 2], [0] * n_events)
179+
assert_allclose(other.get_data("stim")[0, onsets_samp_other + 2], [1] * n_events)
180+
assert_allclose(other.get_data("stim")[0, offsets_samp_other - 2], [1] * n_events)
181+
assert_allclose(other.get_data("stim")[0, offsets_samp_other + 2], [0] * n_events)
182+
183+
184+
def _annot_to_onset_dur(raw, other):
185+
onsets_raw = raw.annotations.onset - raw.first_time
186+
dur_raw = raw.annotations.duration
187+
188+
onsets_other = other.annotations.onset - other.first_time
189+
dur_other = other.annotations.duration
190+
return onsets_raw, dur_raw, onsets_other, dur_other

0 commit comments

Comments
 (0)