Skip to content

Commit 729bdf3

Browse files
authored
fix sfreq estimation for snirf files (#13184)
1 parent 68f666e commit 729bdf3

File tree

4 files changed

+82
-43
lines changed

4 files changed

+82
-43
lines changed

doc/changes/devel/13184.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with sampling frequency estimation in snirf files, by `Daniel McCloy`_.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
New argument ``sfreq`` to :func:`mne.io.read_raw_snirf`, to allow overriding the sampling frequency estimated from (possibly jittered) sampling periods in the file, by `Daniel McCloy`_.

mne/io/snirf/_snirf.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,22 @@
1414
from ..._freesurfer import get_mni_fiducials
1515
from ...annotations import Annotations
1616
from ...transforms import _frame_to_str, apply_trans
17-
from ...utils import _check_fname, _import_h5py, fill_doc, logger, verbose, warn
17+
from ...utils import (
18+
_check_fname,
19+
_import_h5py,
20+
_validate_type,
21+
fill_doc,
22+
logger,
23+
verbose,
24+
warn,
25+
)
1826
from ..base import BaseRaw
1927
from ..nirx.nirx import _convert_fnirs_to_head
2028

2129

2230
@fill_doc
2331
def read_raw_snirf(
24-
fname, optode_frame="unknown", preload=False, verbose=None
32+
fname, optode_frame="unknown", *, sfreq=None, preload=False, verbose=None
2533
) -> "RawSNIRF":
2634
"""Reader for a continuous wave SNIRF data.
2735
@@ -41,6 +49,11 @@ def read_raw_snirf(
4149
in which case the positions are not modified. If a known coordinate
4250
frame is provided (head, meg, mri), then the positions are transformed
4351
in to the Neuromag head coordinate frame (head).
52+
sfreq : float | None
53+
The nominal sampling frequency at which the data were acquired. If ``None``,
54+
will be estimated from the time data in the file.
55+
56+
.. versionadded:: 1.10
4457
%(preload)s
4558
%(verbose)s
4659
@@ -54,7 +67,7 @@ def read_raw_snirf(
5467
--------
5568
mne.io.Raw : Documentation of attributes and methods of RawSNIRF.
5669
"""
57-
return RawSNIRF(fname, optode_frame, preload, verbose)
70+
return RawSNIRF(fname, optode_frame, sfreq=sfreq, preload=preload, verbose=verbose)
5871

5972

6073
def _open(fname):
@@ -74,6 +87,11 @@ class RawSNIRF(BaseRaw):
7487
in which case the positions are not modified. If a known coordinate
7588
frame is provided (head, meg, mri), then the positions are transformed
7689
in to the Neuromag head coordinate frame (head).
90+
sfreq : float | None
91+
The nominal sampling frequency at which the data were acquired. If ``None``,
92+
will be estimated from the time data in the file.
93+
94+
.. versionadded:: 1.10
7795
%(preload)s
7896
%(verbose)s
7997
@@ -83,7 +101,9 @@ class RawSNIRF(BaseRaw):
83101
"""
84102

85103
@verbose
86-
def __init__(self, fname, optode_frame="unknown", preload=False, verbose=None):
104+
def __init__(
105+
self, fname, optode_frame="unknown", *, sfreq=None, preload=False, verbose=None
106+
):
87107
# Must be here due to circular import error
88108
from ...preprocessing.nirs import _validate_nirs_info
89109

@@ -120,7 +140,7 @@ def __init__(self, fname, optode_frame="unknown", preload=False, verbose=None):
120140

121141
last_samps = dat.get("/nirs/data1/dataTimeSeries").shape[0] - 1
122142

123-
sampling_rate = _extract_sampling_rate(dat)
143+
sampling_rate = _extract_sampling_rate(dat, sfreq)
124144

125145
if sampling_rate == 0:
126146
warn("Unable to extract sample rate from SNIRF file.")
@@ -531,49 +551,48 @@ def _get_lengthunit_scaling(length_unit):
531551
)
532552

533553

534-
def _extract_sampling_rate(dat):
554+
def _extract_sampling_rate(dat, user_sfreq):
535555
"""Extract the sample rate from the time field."""
536556
# This is a workaround to provide support for Artinis data.
537557
# It allows for a 1% variation in the sampling times relative
538558
# to the average sampling rate of the file.
539559
MAXIMUM_ALLOWED_SAMPLING_JITTER_PERCENTAGE = 1.0
540560

561+
_validate_type(user_sfreq, ("numeric", None), "sfreq")
541562
time_data = np.array(dat.get("nirs/data1/time"))
542-
sampling_rate = 0
543-
if len(time_data) == 2:
544-
# specified as onset, samplerate
545-
sampling_rate = 1.0 / (time_data[1] - time_data[0])
563+
time_unit = _get_metadata_str(dat, "TimeUnit")
564+
time_unit_scaling = _get_timeunit_scaling(time_unit) # always 1 (s) or 1000 (ms)
565+
if len(time_data) == 2: # special-cased in the snirf standard as (onset, period)
566+
onset, period = time_data
567+
file_sfreq = time_unit_scaling / period
546568
else:
547-
# specified as time points
569+
onset = time_data[0]
548570
periods = np.diff(time_data)
549-
uniq_periods = np.unique(periods.round(decimals=4))
550-
if uniq_periods.size == 1:
551-
# Uniformly sampled data
552-
sampling_rate = 1.0 / uniq_periods.item()
553-
else:
554-
# Hopefully uniformly sampled data with some precision issues.
555-
# This is a workaround to provide support for Artinis data.
556-
mean_period = np.mean(periods)
557-
sampling_rate = 1.0 / mean_period
558-
ideal_times = np.linspace(time_data[0], time_data[-1], time_data.size)
559-
max_jitter = np.max(np.abs(time_data - ideal_times))
560-
percent_jitter = 100.0 * max_jitter / mean_period
561-
msg = (
562-
f"Found jitter of {percent_jitter:3f}% in sample times. Sampling "
563-
f"rate has been set to {sampling_rate:1f}."
571+
sfreqs = time_unit_scaling / periods
572+
file_sfreq = sfreqs.mean() # our best estimate, likely including some jitter
573+
if user_sfreq is not None:
574+
logger.info(f"Setting sampling frequency to user-supplied value: {user_sfreq}")
575+
if not np.allclose(file_sfreq, user_sfreq, rtol=0.01, atol=0):
576+
warn(
577+
f"User-supplied sampling frequency ({user_sfreq} Hz) differs by "
578+
f"{(user_sfreq - file_sfreq) / file_sfreq:.1%} from the frequency "
579+
f"estimated from data in the file ({file_sfreq} Hz)."
564580
)
565-
if percent_jitter > MAXIMUM_ALLOWED_SAMPLING_JITTER_PERCENTAGE:
566-
warn(
567-
f"{msg} Note that MNE-Python does not currently support SNIRF "
568-
"files with non-uniformly-sampled data."
569-
)
570-
else:
571-
logger.info(msg)
572-
time_unit = _get_metadata_str(dat, "TimeUnit")
573-
time_unit_scaling = _get_timeunit_scaling(time_unit)
574-
sampling_rate *= time_unit_scaling
575-
576-
return sampling_rate
581+
sfreq = user_sfreq or file_sfreq # user-passed value overrides value from file
582+
# estimate jitter
583+
if len(time_data) > 2:
584+
ideal_times = onset + np.arange(len(time_data)) / sfreq
585+
max_jitter = np.max(np.abs(time_data - ideal_times))
586+
percent_jitter = 100.0 * max_jitter / periods.mean()
587+
msg = f"Found jitter of {percent_jitter:3f}% in sample times."
588+
if percent_jitter > MAXIMUM_ALLOWED_SAMPLING_JITTER_PERCENTAGE:
589+
warn(
590+
f"{msg} Note that MNE-Python does not currently support SNIRF "
591+
"files with non-uniformly-sampled data."
592+
)
593+
else:
594+
logger.info(msg)
595+
return sfreq
577596

578597

579598
def _get_metadata_str(dat, field):

mne/io/snirf/tests/test_snirf.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import datetime
66
import shutil
7+
from contextlib import nullcontext
78

89
import numpy as np
910
import pytest
@@ -422,7 +423,7 @@ def test_snirf_kernel_hb():
422423
assert raw.copy().pick("hbo")._data.shape == (180, 14)
423424
assert raw.copy().pick("hbr")._data.shape == (180, 14)
424425

425-
assert_allclose(raw.info["sfreq"], 8.257638)
426+
assert_allclose(raw.info["sfreq"], 8.256495)
426427

427428
bad_nans = np.isnan(raw.get_data()).any(axis=1)
428429
assert np.sum(bad_nans) == 20
@@ -434,6 +435,23 @@ def test_snirf_kernel_hb():
434435
assert raw.annotations.description[1] == "StartIti"
435436

436437

438+
@requires_testing_data
439+
@pytest.mark.parametrize(
440+
"sfreq,context",
441+
(
442+
[8.2, nullcontext()], # sfreq estimated from file is 8.256495
443+
[22, pytest.warns(RuntimeWarning, match="User-supplied sampling frequency")],
444+
),
445+
)
446+
def test_user_set_sfreq(sfreq, context):
447+
"""Test manually setting sfreq."""
448+
with context:
449+
# both sfreqs are far enough from true rate to yield >1% jitter
450+
with pytest.warns(RuntimeWarning, match=r"jitter of \d+\.\d*% in sample times"):
451+
raw = read_raw_snirf(kernel_hb, preload=False, sfreq=sfreq)
452+
assert raw.info["sfreq"] == sfreq
453+
454+
437455
@requires_testing_data
438456
@pytest.mark.parametrize(
439457
"fname, boundary_decimal, test_scaling, test_rank",
@@ -535,8 +553,8 @@ def test_sample_rate_jitter(tmp_path):
535553
with h5py.File(new_file, "r+") as f:
536554
orig_time = np.array(f.get("nirs/data1/time"))
537555
acceptable_time_jitter = orig_time.copy()
538-
average_time_diff = np.mean(np.diff(orig_time))
539-
acceptable_time_jitter[-1] += 0.0099 * average_time_diff
556+
mean_period = np.mean(np.diff(orig_time))
557+
acceptable_time_jitter[-1] += 0.0099 * mean_period
540558
del f["nirs/data1/time"]
541559
f.flush()
542560
f.create_dataset("nirs/data1/time", data=acceptable_time_jitter)
@@ -545,11 +563,11 @@ def test_sample_rate_jitter(tmp_path):
545563
lines = "\n".join(line for line in log.getvalue().splitlines() if "jitter" in line)
546564
assert "Found jitter of 0.9" in lines
547565

548-
# Add jitter of 1.01%, which is greater than allowed tolerance
566+
# Add jitter of 1.02%, which is greater than allowed tolerance
549567
with h5py.File(new_file, "r+") as f:
550568
unacceptable_time_jitter = orig_time
551569
unacceptable_time_jitter[-1] = unacceptable_time_jitter[-1] + (
552-
0.0101 * average_time_diff
570+
0.0102 * mean_period
553571
)
554572
del f["nirs/data1/time"]
555573
f.flush()

0 commit comments

Comments
 (0)