Skip to content

Commit 68f666e

Browse files
ENH: Use data-based padding instead of "odd" padding when filtering in raw.plot (#13183)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 80ca925 commit 68f666e

File tree

8 files changed

+177
-93
lines changed

8 files changed

+177
-93
lines changed

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ stages:
113113
- bash: |
114114
set -e
115115
python -m pip install --progress-bar off --upgrade pip
116-
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1" pandas neo pymatreader antio defusedxml
116+
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1" pandas neo pymatreader antio defusedxml
117117
python -m pip uninstall -yq mne
118118
python -m pip install --progress-bar off --upgrade -e .[test]
119119
displayName: 'Install dependencies with pip'

doc/changes/devel/13183.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed bug with filter padding type in :func:`mne.io.Raw.plot` and related functions to reduce edge ringing during data display, by `Eric Larson`_.

mne/cuda.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,9 @@ def _smart_pad(x, n_pad, pad="reflect_limited"):
374374
elif (n_pad < 0).any():
375375
raise RuntimeError("n_pad must be non-negative")
376376
if pad == "reflect_limited":
377-
# need to pad with zeros if len(x) <= npad
378377
l_z_pad = np.zeros(max(n_pad[0] - len(x) + 1, 0), dtype=x.dtype)
379378
r_z_pad = np.zeros(max(n_pad[1] - len(x) + 1, 0), dtype=x.dtype)
380-
return np.concatenate(
379+
out = np.concatenate(
381380
[
382381
l_z_pad,
383382
2 * x[0] - x[n_pad[0] : 0 : -1],
@@ -387,4 +386,8 @@ def _smart_pad(x, n_pad, pad="reflect_limited"):
387386
]
388387
)
389388
else:
390-
return np.pad(x, (tuple(n_pad),), pad)
389+
kwargs = dict()
390+
if pad == "reflect":
391+
kwargs["reflect_type"] = "odd"
392+
out = np.pad(x, (tuple(n_pad),), pad, **kwargs)
393+
return out

mne/filter.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,16 +545,21 @@ def _iir_filter(x, iir_params, picks, n_jobs, copy, phase="zero"):
545545
padlen = min(iir_params["padlen"], x.shape[-1] - 1)
546546
if "sos" in iir_params:
547547
fun = partial(
548-
signal.sosfiltfilt, sos=iir_params["sos"], padlen=padlen, axis=-1
548+
_iir_pad_apply_unpad,
549+
func=signal.sosfiltfilt,
550+
sos=iir_params["sos"],
551+
padlen=padlen,
552+
padtype="reflect_limited",
549553
)
550554
_check_coefficients(iir_params["sos"])
551555
else:
552556
fun = partial(
553-
signal.filtfilt,
557+
_iir_pad_apply_unpad,
558+
func=signal.filtfilt,
554559
b=iir_params["b"],
555560
a=iir_params["a"],
556561
padlen=padlen,
557-
axis=-1,
562+
padtype="reflect_limited",
558563
)
559564
_check_coefficients((iir_params["b"], iir_params["a"]))
560565
else:
@@ -2937,3 +2942,15 @@ def _filt_update_info(info, update_info, l_freq, h_freq):
29372942
):
29382943
with info._unlock():
29392944
info["highpass"] = float(l_freq)
2945+
2946+
2947+
def _iir_pad_apply_unpad(x, *, func, padlen, padtype, **kwargs):
2948+
x_out = np.reshape(x, (-1, x.shape[-1])).copy()
2949+
for this_x in x_out:
2950+
x_ext = this_x
2951+
if padlen:
2952+
x_ext = _smart_pad(x_ext, (padlen, padlen), padtype)
2953+
x_ext = func(x=x_ext, axis=-1, padlen=0, **kwargs)
2954+
this_x[:] = x_ext[padlen : len(x_ext) - padlen]
2955+
x_out.shape = x.shape
2956+
return x_out

mne/tests/test_filter.py

Lines changed: 100 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -97,82 +97,80 @@ def test_estimate_ringing():
9797
assert estimate_ringing_samples(butter(4, 0.00001)) == 100000
9898

9999

100-
def test_1d_filter():
100+
@pytest.mark.parametrize("n_signal", (1, 2, 3, 5, 10, 20, 40))
101+
@pytest.mark.parametrize("n_filter", (1, 2, 3, 5, 10, 11, 20, 21, 40, 41, 100, 101))
102+
@pytest.mark.parametrize("filter_type", ("identity", "random"))
103+
def test_1d_filter(n_signal, n_filter, filter_type):
101104
"""Test our private overlap-add filtering function."""
102105
# make some random signals and filters
103106
rng = np.random.RandomState(0)
104-
for n_signal in (1, 2, 3, 5, 10, 20, 40):
105-
x = rng.randn(n_signal)
106-
for n_filter in (1, 2, 3, 5, 10, 11, 20, 21, 40, 41, 100, 101):
107-
for filter_type in ("identity", "random"):
108-
if filter_type == "random":
109-
h = rng.randn(n_filter)
110-
else: # filter_type == 'identity'
111-
h = np.concatenate([[1.0], np.zeros(n_filter - 1)])
112-
# ensure we pad the signal the same way for both filters
113-
n_pad = n_filter - 1
114-
x_pad = _smart_pad(x, (n_pad, n_pad))
115-
for phase in ("zero", "linear", "zero-double"):
116-
# compute our expected result the slow way
117-
if phase == "zero":
118-
# only allow zero-phase for odd-length filters
119-
if n_filter % 2 == 0:
120-
pytest.raises(
121-
RuntimeError,
122-
_overlap_add_filter,
123-
x[np.newaxis],
124-
h,
125-
phase=phase,
126-
)
127-
continue
128-
shift = (len(h) - 1) // 2
129-
x_expected = np.convolve(x_pad, h)
130-
x_expected = x_expected[shift : len(x_expected) - shift]
131-
elif phase == "zero-double":
132-
shift = len(h) - 1
133-
x_expected = np.convolve(x_pad, h)
134-
x_expected = np.convolve(x_expected[::-1], h)[::-1]
135-
x_expected = x_expected[shift : len(x_expected) - shift]
136-
shift = 0
137-
else:
138-
shift = 0
139-
x_expected = np.convolve(x_pad, h)
140-
x_expected = x_expected[: len(x_expected) - len(h) + 1]
141-
# remove padding
142-
if n_pad > 0:
143-
x_expected = x_expected[n_pad : len(x_expected) - n_pad]
144-
assert len(x_expected) == len(x)
145-
# make sure we actually set things up reasonably
146-
if filter_type == "identity":
147-
out = x_pad.copy()
148-
out = out[shift + n_pad :]
149-
out = out[: len(x)]
150-
out = np.concatenate((out, np.zeros(max(len(x) - len(out), 0))))
151-
assert len(out) == len(x)
152-
assert_allclose(out, x_expected)
153-
assert len(x_expected) == len(x)
154-
155-
# compute our version
156-
for n_fft in (None, 32, 128, 129, 1023, 1024, 1025, 2048):
157-
# need to use .copy() b/c signal gets modified inplace
158-
x_copy = x[np.newaxis, :].copy()
159-
min_fft = 2 * n_filter - 1
160-
if phase == "zero-double":
161-
min_fft = 2 * min_fft - 1
162-
if n_fft is not None and n_fft < min_fft:
163-
pytest.raises(
164-
ValueError,
165-
_overlap_add_filter,
166-
x_copy,
167-
h,
168-
n_fft,
169-
phase=phase,
170-
)
171-
else:
172-
x_filtered = _overlap_add_filter(
173-
x_copy, h, n_fft, phase=phase
174-
)[0]
175-
assert_allclose(x_filtered, x_expected, atol=1e-13)
107+
x = rng.randn(n_signal)
108+
if filter_type == "random":
109+
h = rng.randn(n_filter)
110+
else: # filter_type == 'identity'
111+
h = np.concatenate([[1.0], np.zeros(n_filter - 1)])
112+
# ensure we pad the signal the same way for both filters
113+
n_pad = n_filter
114+
x_pad = _smart_pad(x, (n_pad, n_pad))
115+
for phase in ("zero", "linear", "zero-double"):
116+
# compute our expected result the slow way
117+
if phase == "zero":
118+
# only allow zero-phase for odd-length filters
119+
if n_filter % 2 == 0:
120+
pytest.raises(
121+
RuntimeError,
122+
_overlap_add_filter,
123+
x[np.newaxis],
124+
h,
125+
phase=phase,
126+
)
127+
continue
128+
shift = (len(h) - 1) // 2
129+
x_expected = np.convolve(x_pad, h)
130+
x_expected = x_expected[shift : len(x_expected) - shift]
131+
elif phase == "zero-double":
132+
shift = len(h) - 1
133+
x_expected = np.convolve(x_pad, h)
134+
x_expected = np.convolve(x_expected[::-1], h)[::-1]
135+
x_expected = x_expected[shift : len(x_expected) - shift]
136+
shift = 0
137+
else:
138+
shift = 0
139+
x_expected = np.convolve(x_pad, h)
140+
x_expected = x_expected[: len(x_expected) - len(h) + 1]
141+
# remove padding
142+
if n_pad > 0:
143+
x_expected = x_expected[n_pad : len(x_expected) - n_pad]
144+
assert len(x_expected) == len(x)
145+
# make sure we actually set things up reasonably
146+
if filter_type == "identity":
147+
out = x_pad.copy()
148+
out = out[shift + n_pad :]
149+
out = out[: len(x)]
150+
out = np.concatenate((out, np.zeros(max(len(x) - len(out), 0))))
151+
assert len(out) == len(x)
152+
assert_allclose(out, x_expected)
153+
assert len(x_expected) == len(x)
154+
155+
# compute our version
156+
for n_fft in (None, 32, 128, 129, 1023, 1024, 1025, 2048):
157+
# need to use .copy() b/c signal gets modified inplace
158+
x_copy = x[np.newaxis, :].copy()
159+
min_fft = 2 * n_filter - 1
160+
if phase == "zero-double":
161+
min_fft = 2 * min_fft - 1
162+
if n_fft is not None and n_fft < min_fft:
163+
pytest.raises(
164+
ValueError,
165+
_overlap_add_filter,
166+
x_copy,
167+
h,
168+
n_fft,
169+
phase=phase,
170+
)
171+
else:
172+
x_filtered = _overlap_add_filter(x_copy, h, n_fft, phase=phase)[0]
173+
assert_allclose(x_filtered, x_expected, atol=1e-13)
176174

177175

178176
def test_iir_stability():
@@ -1107,3 +1105,32 @@ def test_filter_minimum_phase_bug():
11071105
dB_min_half = 20 * np.log10(np.abs(H_min_half[mask]))
11081106
assert_array_less(dB_min_half, -20)
11091107
assert not (dB_min_half < -30).all()
1108+
1109+
1110+
@pytest.mark.parametrize("dc", (0, 100))
1111+
@pytest.mark.parametrize("sfreq", (1000.0, 999.0))
1112+
def test_smart_pad(dc, sfreq):
1113+
"""Test that smart pad does what it should."""
1114+
f = 1.0
1115+
t = (np.arange(0, int(round(sfreq)) + 1)) / sfreq
1116+
x = np.sin(2 * np.pi * f * t) + dc
1117+
x_want = np.r_[x, x, x]
1118+
padlen = (len(x),) * 2
1119+
x_pad = np.pad(x, padlen, mode="reflect", reflect_type="odd")
1120+
assert_allclose(x_pad, x_want, err_msg="np.pad", atol=0.1) # slight DC shift
1121+
this_x_pad = _smart_pad(x, padlen, "reflect")
1122+
# slight DC shift from out "x_want" (doubled sample at each end)
1123+
assert_allclose(this_x_pad, x_want, atol=0.1, err_msg="_smart_pad reflect x_want")
1124+
assert_allclose(
1125+
this_x_pad, x_pad, atol=1e-6, err_msg="_smart_pad reflect vs np.pad"
1126+
)
1127+
this_x_pad = _smart_pad(x, padlen, "reflect_limited")
1128+
x_want = np.r_[0, x[:-1], x, x[1:], 0]
1129+
assert_allclose(
1130+
this_x_pad, x_want, atol=1e-7, err_msg="_smart_pad reflect_limited x_want"
1131+
)
1132+
# reflect_limited uses one fewer sample so ends up a little bit different
1133+
# with even more padding
1134+
x_want = np.r_[np.zeros_like(x), x_want, np.zeros_like(x)]
1135+
x_pad = _smart_pad(x, (len(x) * 2,) * 2, "reflect_limited")
1136+
assert_allclose(x_pad, x_want, atol=0.1, err_msg="reflect_limited with zeros")

mne/viz/_figure.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Copyright the MNE-Python contributors.
66

77
import importlib
8+
import inspect
89
from abc import ABC, abstractmethod
910
from collections import OrderedDict
1011
from contextlib import contextmanager
@@ -307,25 +308,29 @@ def _make_butterfly_selections_dict(self):
307308
def _get_start_stop(self):
308309
# update time
309310
start_sec = self.mne.t_start - self.mne.first_time
310-
stop_sec = start_sec + self.mne.duration
311311
if self.mne.is_epochs:
312312
start, stop = np.round(
313-
np.array([start_sec, stop_sec]) * self.mne.info["sfreq"]
313+
np.array([start_sec, start_sec + self.mne.duration])
314+
* self.mne.info["sfreq"]
314315
).astype(int)
315316
else:
317+
# ensure our end time includes the last sample
318+
disp_duration = (
319+
np.ceil(self.mne.duration * self.mne.info["sfreq"])
320+
/ self.mne.info["sfreq"]
321+
)
322+
stop_sec = start_sec + disp_duration
316323
start, stop = self.mne.inst.time_as_index((start_sec, stop_sec))
317324

318325
return start, stop
319326

320327
def _load_data(self, start=None, stop=None):
321328
"""Retrieve the bit of data we need for plotting."""
322329
if "raw" in (self.mne.instance_type, self.mne.ica_type):
323-
# Add additional sample to cover the case sfreq!=1000
324-
# when the shown time-range wouldn't correspond to duration anymore
325330
if stop is None:
326331
return self.mne.inst[:, start:]
327332
else:
328-
return self.mne.inst[:, start : stop + 2]
333+
return self.mne.inst[:, start:stop]
329334
else:
330335
# subtract one sample from tstart before searchsorted, to make sure
331336
# we land on the left side of the boundary time (avoid precision
@@ -362,9 +367,11 @@ def _apply_filter(self, data, start, stop, picks):
362367
)
363368
data[_picks, _start:_stop] = this_data
364369

365-
def _process_data(self, data, start, stop, picks, thread=None):
370+
def _process_data(self, data, start, stop, picks, thread=None, *, time_slice=None):
366371
"""Update self.mne.data after user interaction."""
367372
# apply projectors
373+
if time_slice is None:
374+
time_slice = slice(None)
368375
if self.mne.projector is not None:
369376
# thread is the loading-thread only available in Qt-backend
370377
if thread:
@@ -376,12 +383,13 @@ def _process_data(self, data, start, stop, picks, thread=None):
376383
if self.mne.remove_dc:
377384
if thread:
378385
thread.processText.emit("Removing DC...")
379-
data -= np.nanmean(data, axis=1, keepdims=True)
386+
data -= np.nanmean(data[..., time_slice], axis=1, keepdims=True)
380387
# apply filter
381388
if self.mne.filter_coefs is not None:
382389
if thread:
383390
thread.processText.emit("Apply Filter...")
384391
self._apply_filter(data, start, stop, picks)
392+
data = data[..., time_slice]
385393
# scale the data for display in a 1-vertical-axis-unit slot
386394
if thread:
387395
thread.processText.emit("Scale Data...")
@@ -400,12 +408,41 @@ def _process_data(self, data, start, stop, picks, thread=None):
400408

401409
return data
402410

411+
@property
412+
def _has_time_slice(self):
413+
# check that mne-qt-browser is new enough to support time_slice
414+
specs = inspect.getfullargspec(self._process_data)
415+
return "time_slice" in specs.kwonlyargs or specs.varkw
416+
403417
def _update_data(self):
404418
start, stop = self._get_start_stop()
405-
# get the data
406-
data, times = self._load_data(start, stop)
419+
# get the data, with padding if necessary
420+
kwargs = dict()
421+
padlen = None
422+
if isinstance(self.mne.filter_coefs, dict) and self._has_time_slice: # IIR
423+
padlen = self.mne.filter_coefs["padlen"]
424+
use_start = max(0, start - padlen)
425+
use_stop = min(self.mne.n_times, stop + padlen)
426+
# now during filt step, only pad as much as needed
427+
self.mne.filter_coefs["padlen"] = max(
428+
padlen - (use_stop - stop), padlen - (start - use_start)
429+
)
430+
time_slice = slice(start - use_start, start - use_start + (stop - start))
431+
kwargs["time_slice"] = time_slice
432+
else:
433+
use_start, use_stop = start, stop
434+
time_slice = slice(None)
435+
436+
data, times = self._load_data(use_start, use_stop)
437+
assert data.ndim >= 2 and data.shape[-1] == (use_stop - use_start)
407438
# process the data
408-
data = self._process_data(data, start, stop, self.mne.picks)
439+
data = self._process_data(
440+
data, use_start, use_stop, picks=self.mne.picks, **kwargs
441+
)
442+
if padlen is not None:
443+
self.mne.filter_coefs["padlen"] = padlen
444+
times = times[time_slice]
445+
assert data.ndim >= 2 and data.shape[-1] == (stop - start)
409446
# set the data as attributes
410447
self.mne.data = data
411448
self.mne.times = times

tools/circleci_dependencies.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ python -m pip install --upgrade --progress-bar off \
77
"git+https://github.com/pyvista/pyvista.git" \
88
"git+https://github.com/sphinx-gallery/sphinx-gallery.git" \
99
"git+https://github.com/mne-tools/mne-bids.git" \
10+
"git+https://github.com/mne-tools/mne-qt-browser.git" \
1011
\
1112
alphaCSC autoreject bycycle conpy emd fooof meggie \
1213
mne-ari mne-bids-pipeline mne-faster mne-features \

tools/install_pre_requirements.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ echo "OpenMEEG"
4444
python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://test.pypi.org/simple" "openmeeg>=2.6.0.dev4"
4545

4646
echo "nilearn"
47-
# TODO: Revert once settled:
48-
# https://github.com/scikit-learn/scikit-learn/pull/30268#issuecomment-2479701651
49-
python -m pip install $STD_ARGS "git+https://github.com/larsoner/nilearn@sklearn"
47+
python -m pip install $STD_ARGS "git+https://github.com/nilearn/nilearn"
5048

5149
echo "VTK"
5250
python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk

0 commit comments

Comments
 (0)