diff --git a/doc/changes/devel/13156.newfeature.rst b/doc/changes/devel/13156.newfeature.rst new file mode 100644 index 00000000000..4fe07ebf646 --- /dev/null +++ b/doc/changes/devel/13156.newfeature.rst @@ -0,0 +1 @@ +Added support for file-like objects in :func:`read_raw_bdf <mne.io.read_raw_bdf>`, :func:`read_raw_edf <mne.io.read_raw_edf>` and :func:`read_raw_gdf <mne.io.read_raw_gdf>`, by `Santi Martínez`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index d20931b7a51..1b17a834b44 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -270,6 +270,7 @@ .. _Samuel Louviot: https://github.com/Sam54000 .. _Samuel Powell: https://github.com/samuelpowell .. _Santeri Ruuskanen: https://github.com/ruuskas +.. _Santi Martínez: https://github.com/szz-dvl .. _Sara Sommariva: https://github.com/sarasommariva .. _Sawradip Saha: https://sawradip.github.io/ .. _Scott Huberty: https://orcid.org/0000-0003-2637-031X diff --git a/mne/_edf/open.py b/mne/_edf/open.py new file mode 100644 index 00000000000..2fd97833b29 --- /dev/null +++ b/mne/_edf/open.py @@ -0,0 +1,23 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# Maybe we can move this one to utils or something like that. +from pathlib import Path + +from mne._fiff.open import _NoCloseRead + +from ..utils import _file_like, _validate_type, logger + + +def _gdf_edf_get_fid(fname, **kwargs): + """Open an EDF/BDF/GDF file with no additional parsing.""" + if _file_like(fname): + logger.debug("Using file-like I/O") + fid = _NoCloseRead(fname) + fid.seek(0) + else: + _validate_type(fname, [Path, str], "fname", extra="or file-like") + logger.debug("Using normal I/O") + fid = open(fname, "rb", **kwargs) # Open in binary mode + return fid diff --git a/mne/fixes.py b/mne/fixes.py index 070d4125d18..2148330fb34 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -16,12 +16,14 @@ # because this module is imported many places (but not always used)! 
import inspect +import io import operator as operator_module import os import warnings from math import log import numpy as np +import numpy.typing from packaging.version import parse ############################################################################### @@ -733,3 +735,33 @@ def sph_harm_y(n, m, theta, phi, *, diff_n=0): return special.sph_harm_y(n, m, theta, phi, diff_n=diff_n) else: return special.sph_harm(m, n, phi, theta) + + +############################################################################### +# workaround: Numpy won't allow to read from file-like objects with numpy.fromfile, +# we try to use numpy.fromfile, if a blob is used we use numpy.frombuffer to read +# from the file-like object. +def read_from_file_or_buffer( + file: str | bytes | os.PathLike | io.IOBase, + dtype: numpy.typing.DTypeLike = float, + count: int = -1, +): + """numpy.fromfile() wrapper, handling io.BytesIO file-like streams. + + Numpy requires open files to be actual files on disk, i.e., must support + file.fileno(), so it fails with file-like streams such as io.BytesIO(). + + If numpy.fromfile() fails due to no file.fileno() support, this wrapper + reads the required bytes from file and redirects the call to + numpy.frombuffer(). 
+ + See https://github.com/numpy/numpy/issues/2230#issuecomment-949795210 + """ + try: + return np.fromfile(file, dtype=dtype, count=count) + except io.UnsupportedOperation as e: + if not (e.args and e.args[0] == "fileno" and isinstance(file, io.IOBase)): + raise # Nothing I can do about it + dtype = np.dtype(dtype) + buffer = file.read(dtype.itemsize * count) + return np.frombuffer(buffer, dtype=dtype, count=count) diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index 763ef4f91eb..481f5a43364 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -7,19 +7,39 @@ import os import re from datetime import date, datetime, timedelta, timezone +from enum import Enum from pathlib import Path import numpy as np from scipy.interpolate import interp1d +from ..._edf.open import _gdf_edf_get_fid from ..._fiff.constants import FIFF from ..._fiff.meas_info import _empty_info, _unique_channel_names from ..._fiff.utils import _blk_read_lims, _mult_cal_one from ...annotations import Annotations from ...filter import resample -from ...utils import _validate_type, fill_doc, logger, verbose, warn +from ...fixes import read_from_file_or_buffer +from ...utils import ( + _check_fname, + _file_like, + _validate_type, + fill_doc, + logger, + verbose, + warn, +) from ..base import BaseRaw, _get_scaling + +class FileType(Enum): + """Enumeration to differentiate files when the extension is not known.""" + + GDF = 1 + EDF = 2 + BDF = 3 + + # common channel type names mapped to internal ch types CH_TYPE_MAPPING = { "EEG": FIFF.FIFFV_EEG_CH, @@ -40,12 +60,17 @@ @fill_doc class RawEDF(BaseRaw): - """Raw object from EDF, EDF+ or BDF file. + """Raw object from EDF, EDF+ file. Parameters ---------- - input_fname : path-like - Path to the EDF, EDF+ or BDF file. + input_fname : path-like | file-like + Path to the EDF, EDF+ file. If a file-like object is provided, + preloading must be used. + + .. 
versionchanged:: 1.10 + Added support for file-like objects + eog : list or tuple Names of channels or list of indices that should be designated EOG channels. Values should correspond to the electrodes in the file. @@ -88,7 +113,6 @@ class RawEDF(BaseRaw): -------- mne.io.Raw : Documentation of attributes and methods. mne.io.read_raw_edf : Recommended way to read EDF/EDF+ files. - mne.io.read_raw_bdf : Recommended way to read BDF files. Notes ----- @@ -120,7 +144,7 @@ class RawEDF(BaseRaw): >>> events[:, 2] >>= 8 # doctest:+SKIP - TAL channels called 'EDF Annotations' or 'BDF Annotations' are parsed and + TAL channels called 'EDF Annotations' are parsed and extracted annotations are stored in raw.annotations. Use :func:`mne.events_from_annotations` to obtain events from these annotations. @@ -147,8 +171,10 @@ def __init__( *, verbose=None, ): - logger.info(f"Extracting EDF parameters from {input_fname}...") - input_fname = os.path.abspath(input_fname) + if not _file_like(input_fname): + logger.info(f"Extracting EDF parameters from {input_fname}...") + input_fname = os.path.abspath(input_fname) + info, edf_info, orig_units = _get_info( input_fname, stim_channel, @@ -156,11 +182,225 @@ def __init__( misc, exclude, infer_types, + FileType.EDF, + include, + exclude_after_unique, + ) + logger.info("Creating raw.info structure...") + edf_info["blob"] = input_fname if _file_like(input_fname) else None + + _validate_type(units, (str, None, dict), "units") + if units is None: + units = dict() + elif isinstance(units, str): + units = {ch_name: units for ch_name in info["ch_names"]} + + for k, (this_ch, this_unit) in enumerate(orig_units.items()): + if this_ch not in units: + continue + if this_unit not in ("", units[this_ch]): + raise ValueError( + f"Unit for channel {this_ch} is present in the file as " + f"{repr(this_unit)}, cannot overwrite it with the units " + f"argument {repr(units[this_ch])}." 
+ ) + if this_unit == "": + orig_units[this_ch] = units[this_ch] + ch_type = edf_info["ch_types"][k] + scaling = _get_scaling(ch_type.lower(), orig_units[this_ch]) + edf_info["units"][k] /= scaling + + # Raw attributes + last_samps = [edf_info["nsamples"] - 1] + super().__init__( + info, preload, + filenames=[_path_from_fname(input_fname)], + raw_extras=[edf_info], + last_samps=last_samps, + orig_format="int", + orig_units=orig_units, + verbose=verbose, + ) + + # Read annotations from file and set it + if len(edf_info["tal_idx"]) > 0: + # Read TAL data exploiting the header info (no regexp) + idx = np.empty(0, int) + tal_data = self._read_segment_file( + np.empty((0, self.n_times)), + idx, + 0, + 0, + int(self.n_times), + np.ones((len(idx), 1)), + None, + ) + annotations = _read_annotations_edf( + tal_data[0], + ch_names=info["ch_names"], + encoding=encoding, + ) + self.set_annotations(annotations, on_missing="warn") + + def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): + """Read a chunk of raw data.""" + return _read_segment_file( + data, + idx, + fi, + start, + stop, + self._raw_extras[fi], + self.filenames[fi] + if self._raw_extras[fi]["blob"] is None + else self._raw_extras[fi]["blob"], + cals, + mult, + ) + + +def _path_from_fname(fname) -> Path | None: + if isinstance(fname, str | Path): + return Path(fname) + + # Try to get a filename from the file-like object + try: + return Path(fname.name) + except Exception: + return None + + +@fill_doc +class RawBDF(BaseRaw): + """Raw object from BDF file. + + Parameters + ---------- + input_fname : path-like | file-like + Path to the BDF file. If a file-like object is provided, + preloading must be used. + + .. versionchanged:: 1.10 + Added support for file-like objects + + eog : list or tuple + Names of channels or list of indices that should be designated EOG + channels. Values should correspond to the electrodes in the file. + Default is None. 
+ misc : list or tuple + Names of channels or list of indices that should be designated MISC + channels. Values should correspond to the electrodes in the file. + Default is None. + stim_channel : ``'auto'`` | str | list of str | int | list of int + Defaults to ``'auto'``, which means that channels named ``'status'`` or + ``'trigger'`` (case insensitive) are set to STIM. If str (or list of + str), all channels matching the name(s) are set to STIM. If int (or + list of ints), the channels corresponding to the indices are set to + STIM. + exclude : list of str + Channel names to exclude. This can help when reading data with + different sampling rates to avoid unnecessary resampling. + infer_types : bool + If True, try to infer channel types from channel labels. If a channel + label starts with a known type (such as 'EEG') followed by a space and + a name (such as 'Fp1'), the channel type will be set accordingly, and + the channel will be renamed to the original label without the prefix. + For unknown prefixes, the type will be 'EEG' and the name will not be + modified. If False, do not infer types and assume all channels are of + type 'EEG'. + + .. versionadded:: 0.24.1 + include : list of str | str + Channel names to be included. A str is interpreted as a regular + expression. 'exclude' must be empty if include is assigned. + + .. versionadded:: 1.1 + %(preload)s + %(units_edf_bdf_io)s + %(encoding_edf)s + %(exclude_after_unique)s + %(verbose)s + + See Also + -------- + mne.io.Raw : Documentation of attributes and methods. + mne.io.read_raw_bdf : Recommended way to read BDF files. + + Notes + ----- + %(edf_resamp_note)s + + Biosemi devices trigger codes are encoded in 16-bit format, whereas system + codes (CMS in/out-of range, battery low, etc.) are coded in bits 16-23 of + the status channel (see http://www.biosemi.com/faq/trigger_signals.htm). + To retrieve correct event values (bits 1-16), one could do: + + >>> events = mne.find_events(...) 
# doctest:+SKIP + >>> events[:, 2] &= (2**16 - 1) # doctest:+SKIP + + The above operation can be carried out directly in :func:`mne.find_events` + using the ``mask`` and ``mask_type`` parameters (see + :func:`mne.find_events` for more details). + + It is also possible to retrieve system codes, but no particular effort has + been made to decode these in MNE. In case it is necessary, for instance to + check the CMS bit, the following operation can be carried out: + + >>> cms_bit = 20 # doctest:+SKIP + >>> cms_high = (events[:, 2] & (1 << cms_bit)) != 0 # doctest:+SKIP + + It is worth noting that in some special cases, it may be necessary to shift + event values in order to retrieve correct event triggers. This depends on + the triggering device used to perform the synchronization. For instance, in + some files events need to be shifted by 8 bits: + + >>> events[:, 2] >>= 8 # doctest:+SKIP + + TAL channels called 'BDF Annotations' are parsed and + extracted annotations are stored in raw.annotations. Use + :func:`mne.events_from_annotations` to obtain events from these + annotations. + + If channels named 'status' or 'trigger' are present, they are considered as + STIM channels by default. Use func:`mne.find_events` to parse events + encoded in such analog stim channels. 
+ """ + + @verbose + def __init__( + self, + input_fname, + eog=None, + misc=None, + stim_channel="auto", + exclude=(), + infer_types=False, + preload=False, + include=None, + units=None, + encoding="utf8", + exclude_after_unique=False, + *, + verbose=None, + ): + if not _file_like(input_fname): + logger.info(f"Extracting BDF parameters from {input_fname}...") + input_fname = os.path.abspath(input_fname) + + info, edf_info, orig_units = _get_info( + input_fname, + stim_channel, + eog, + misc, + exclude, + infer_types, + FileType.BDF, include, exclude_after_unique, ) logger.info("Creating raw.info structure...") + edf_info["blob"] = input_fname if _file_like(input_fname) else None _validate_type(units, (str, None, dict), "units") if units is None: @@ -188,7 +428,7 @@ def __init__( super().__init__( info, preload, - filenames=[input_fname], + filenames=[_path_from_fname(input_fname)], raw_extras=[edf_info], last_samps=last_samps, orig_format="int", @@ -225,7 +465,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): start, stop, self._raw_extras[fi], - self.filenames[fi], + self.filenames[fi] + if self._raw_extras[fi]["blob"] is None + else self._raw_extras[fi]["blob"], cals, mult, ) @@ -237,8 +479,13 @@ class RawGDF(BaseRaw): Parameters ---------- - input_fname : path-like - Path to the GDF file. + input_fname : path-like | file-like + Path to the GDF file. If a file-like object is provided, + preloading must be used. + + .. versionchanged:: 1.10 + Added support for file-like objects + eog : list or tuple Names of channels or list of indices that should be designated EOG channels. Values should correspond to the electrodes in the file. 
@@ -289,19 +536,29 @@ def __init__( include=None, verbose=None, ): - logger.info(f"Extracting EDF parameters from {input_fname}...") - input_fname = os.path.abspath(input_fname) + if not _file_like(input_fname): + logger.info(f"Extracting GDF parameters from {input_fname}...") + input_fname = os.path.abspath(input_fname) + info, edf_info, orig_units = _get_info( - input_fname, stim_channel, eog, misc, exclude, True, preload, include + input_fname, + stim_channel, + eog, + misc, + exclude, + True, + FileType.GDF, + include, ) logger.info("Creating raw.info structure...") + edf_info["blob"] = input_fname if _file_like(input_fname) else None # Raw attributes last_samps = [edf_info["nsamples"] - 1] super().__init__( info, preload, - filenames=[input_fname], + filenames=[_path_from_fname(input_fname)], raw_extras=[edf_info], last_samps=last_samps, orig_format="int", @@ -327,7 +584,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): start, stop, self._raw_extras[fi], - self.filenames[fi], + self.filenames[fi] + if self._raw_extras[fi]["blob"] is None + else self._raw_extras[fi]["blob"], cals, mult, ) @@ -337,7 +596,7 @@ def _read_ch(fid, subtype, samp, dtype_byte, dtype=None): """Read a number of samples for a single channel.""" # BDF if subtype == "bdf": - ch_data = np.fromfile(fid, dtype=dtype, count=samp * dtype_byte) + ch_data = read_from_file_or_buffer(fid, dtype=dtype, count=samp * dtype_byte) ch_data = ch_data.reshape(-1, 3).astype(INT32) ch_data = (ch_data[:, 0]) + (ch_data[:, 1] << 8) + (ch_data[:, 2] << 16) # 24th bit determines the sign @@ -345,7 +604,7 @@ def _read_ch(fid, subtype, samp, dtype_byte, dtype=None): # GDF data and EDF data else: - ch_data = np.fromfile(fid, dtype=dtype, count=samp) + ch_data = read_from_file_or_buffer(fid, dtype=dtype, count=samp) return ch_data @@ -379,7 +638,8 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, # Otherwise we can end up with e.g. 
18,181 chunks for a 20 MB file! # Let's do ~10 MB chunks: n_per = max(10 * 1024 * 1024 // (ch_offsets[-1] * dtype_byte), 1) - with open(filenames, "rb", buffering=0) as fid: + + with _gdf_edf_get_fid(filenames, buffering=0) as fid: # Extract data start_offset = data_offset + block_start_idx * ch_offsets[-1] * dtype_byte @@ -481,13 +741,20 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, @fill_doc -def _read_header(fname, exclude, infer_types, include=None, exclude_after_unique=False): +def _read_header( + fname, + exclude, + infer_types, + file_type, + include=None, + exclude_after_unique=False, +): """Unify EDF, BDF and GDF _read_header call. Parameters ---------- fname : str - Path to the EDF+, BDF, or GDF file. + Path to the EDF+, BDF, or GDF file or file-like object. exclude : list of str | str Channel names to exclude. This can help when reading data with different sampling rates to avoid unnecessary resampling. A str is @@ -509,18 +776,19 @@ def _read_header(fname, exclude, infer_types, include=None, exclude_after_unique ------- (edf_info, orig_units) : tuple """ - ext = os.path.splitext(fname)[1][1:].lower() - logger.info(f"{ext.upper()} file detected") - if ext in ("bdf", "edf"): + if file_type in (FileType.BDF, FileType.EDF): return _read_edf_header( - fname, exclude, infer_types, include, exclude_after_unique + fname, + exclude, + infer_types, + file_type, + include, + exclude_after_unique, ) - elif ext == "gdf": + elif file_type == FileType.GDF: return _read_gdf_header(fname, exclude, include), None else: - raise NotImplementedError( - f"Only GDF, EDF, and BDF files are supported, got {ext}." 
- ) + raise NotImplementedError("Only GDF, EDF, and BDF files are supported.") def _get_info( @@ -530,7 +798,7 @@ def _get_info( misc, exclude, infer_types, - preload, + file_type, include=None, exclude_after_unique=False, ): @@ -539,7 +807,7 @@ def _get_info( misc = misc if misc is not None else [] edf_info, orig_units = _read_header( - fname, exclude, infer_types, include, exclude_after_unique + fname, exclude, infer_types, file_type, include, exclude_after_unique ) # XXX: `tal_ch_names` to pass to `_check_stim_channel` should be computed @@ -801,12 +1069,17 @@ def _edf_str_num(x): def _read_edf_header( - fname, exclude, infer_types, include=None, exclude_after_unique=False + fname, + exclude, + infer_types, + file_type, + include=None, + exclude_after_unique=False, ): """Read header information from EDF+ or BDF file.""" edf_info = {"events": []} - with open(fname, "rb") as fid: + with _gdf_edf_get_fid(fname) as fid: fid.read(8) # version (unused here) # patient ID @@ -877,14 +1150,20 @@ def _read_edf_header( fid.read(8) # skip the file's measurement time warn("Invalid measurement date encountered in the header.") - header_nbytes = int(_edf_str(fid.read(8))) + try: + header_nbytes = int(_edf_str(fid.read(8))) + except ValueError: + raise ValueError( + f"Bad {'EDF' if file_type is FileType.EDF else 'BDF'} file provided." + ) + # The following 44 bytes sometimes identify the file type, but this is - # not guaranteed. Therefore, we skip this field and use the file - # extension to determine the subtype (EDF or BDF, which differ in the + # not guaranteed. Therefore, we skip this field and use the file_type + # to determine the subtype (EDF or BDF, which differ in the # number of bytes they use for the data records; EDF uses 2 bytes # whereas BDF uses 3 bytes). 
fid.read(44) - subtype = os.path.splitext(fname)[1][1:].lower() + subtype = file_type n_records = int(_edf_str(fid.read(8))) record_length = float(_edf_str(fid.read(8))) @@ -996,7 +1275,7 @@ def _read_edf_header( physical_max=physical_max, physical_min=physical_min, record_length=record_length, - subtype=subtype, + subtype="bdf" if subtype == FileType.BDF else "edf", tal_idx=tal_idx, ) @@ -1006,7 +1285,9 @@ def _read_edf_header( fid.seek(0, 2) n_bytes = fid.tell() n_data_bytes = n_bytes - header_nbytes - total_samps = n_data_bytes // 3 if subtype == "bdf" else n_data_bytes // 2 + total_samps = ( + n_data_bytes // 3 if subtype == FileType.BDF else n_data_bytes // 2 + ) read_records = total_samps // np.sum(n_samps) if n_records != read_records: warn( @@ -1017,7 +1298,7 @@ def _read_edf_header( edf_info["n_records"] = read_records del n_records - if subtype == "bdf": + if subtype == FileType.BDF: edf_info["dtype_byte"] = 3 # 24-bit (3 byte) integers edf_info["dtype_np"] = UINT8 else: @@ -1074,10 +1355,15 @@ def _read_gdf_header(fname, exclude, include=None): """Read GDF 1.x and GDF 2.x header info.""" edf_info = dict() events = None - with open(fname, "rb") as fid: - version = fid.read(8).decode() - edf_info["type"] = edf_info["subtype"] = version[:3] - edf_info["number"] = float(version[4:]) + + with _gdf_edf_get_fid(fname) as fid: + try: + version = fid.read(8).decode() + edf_info["type"] = edf_info["subtype"] = version[:3] + edf_info["number"] = float(version[4:]) + except ValueError: + raise ValueError("Bad GDF file provided.") + meas_date = None # GDF 1.x @@ -1113,22 +1399,22 @@ def _read_gdf_header(fname, exclude, include=None): except Exception: pass - header_nbytes = np.fromfile(fid, INT64, 1)[0] - meas_id["equipment"] = np.fromfile(fid, UINT8, 8)[0] - meas_id["hospital"] = np.fromfile(fid, UINT8, 8)[0] - meas_id["technician"] = np.fromfile(fid, UINT8, 8)[0] + header_nbytes = read_from_file_or_buffer(fid, INT64, 1)[0] + meas_id["equipment"] = 
read_from_file_or_buffer(fid, UINT8, 8)[0] + meas_id["hospital"] = read_from_file_or_buffer(fid, UINT8, 8)[0] + meas_id["technician"] = read_from_file_or_buffer(fid, UINT8, 8)[0] fid.seek(20, 1) # 20bytes reserved - n_records = np.fromfile(fid, INT64, 1)[0] + n_records = read_from_file_or_buffer(fid, INT64, 1)[0] # record length in seconds - record_length = np.fromfile(fid, UINT32, 2) + record_length = read_from_file_or_buffer(fid, UINT32, 2) if record_length[0] == 0: record_length[0] = 1.0 warn( "Header information is incorrect for record length. " "Default record length set to 1." ) - nchan = int(np.fromfile(fid, UINT32, 1)[0]) + nchan = int(read_from_file_or_buffer(fid, UINT32, 1)[0]) channels = list(range(nchan)) ch_names = [_edf_str(fid.read(16)).strip() for ch in channels] exclude = _find_exclude_idx(ch_names, exclude, include) @@ -1146,18 +1432,18 @@ def _read_gdf_header(fname, exclude, include=None): edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] - physical_min = np.fromfile(fid, FLOAT64, len(channels)) - physical_max = np.fromfile(fid, FLOAT64, len(channels)) - digital_min = np.fromfile(fid, INT64, len(channels)) - digital_max = np.fromfile(fid, INT64, len(channels)) + physical_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + physical_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_min = read_from_file_or_buffer(fid, INT64, len(channels)) + digital_max = read_from_file_or_buffer(fid, INT64, len(channels)) prefiltering = [_edf_str(fid.read(80)) for ch in channels] highpass, lowpass = _parse_prefilter_string(prefiltering) # n samples per record - n_samps = np.fromfile(fid, INT32, len(channels)) + n_samps = read_from_file_or_buffer(fid, INT32, len(channels)) # channel data type - dtype = np.fromfile(fid, INT32, len(channels)) + dtype = read_from_file_or_buffer(fid, INT32, len(channels)) # total number of bytes for data bytes_tot = np.sum( @@ -1197,19 +1483,21 @@ def 
_read_gdf_header(fname, exclude, include=None): etp = header_nbytes + n_records * edf_info["bytes_tot"] # skip data to go to event table fid.seek(etp) - etmode = np.fromfile(fid, UINT8, 1)[0] + etmode = read_from_file_or_buffer(fid, UINT8, 1)[0] if etmode in (1, 3): - sr = np.fromfile(fid, UINT8, 3).astype(np.uint32) + sr = read_from_file_or_buffer(fid, UINT8, 3).astype(np.uint32) event_sr = sr[0] for i in range(1, len(sr)): event_sr = event_sr + sr[i] * 2 ** (i * 8) - n_events = np.fromfile(fid, UINT32, 1)[0] - pos = np.fromfile(fid, UINT32, n_events) - 1 # 1-based inds - typ = np.fromfile(fid, UINT16, n_events) + n_events = read_from_file_or_buffer(fid, UINT32, 1)[0] + pos = ( + read_from_file_or_buffer(fid, UINT32, n_events) - 1 + ) # 1-based inds + typ = read_from_file_or_buffer(fid, UINT16, n_events) if etmode == 3: - chn = np.fromfile(fid, UINT16, n_events) - dur = np.fromfile(fid, UINT32, n_events) + chn = read_from_file_or_buffer(fid, UINT16, n_events) + dur = read_from_file_or_buffer(fid, UINT32, n_events) else: chn = np.zeros(n_events, dtype=np.int32) dur = np.ones(n_events, dtype=UINT32) @@ -1234,20 +1522,20 @@ def _read_gdf_header(fname, exclude, include=None): fid.seek(10, 1) # 10bytes reserved # Smoking / Alcohol abuse / drug abuse / medication - sadm = np.fromfile(fid, UINT8, 1)[0] + sadm = read_from_file_or_buffer(fid, UINT8, 1)[0] patient["smoking"] = scale[sadm % 4] patient["alcohol_abuse"] = scale[(sadm >> 2) % 4] patient["drug_abuse"] = scale[(sadm >> 4) % 4] patient["medication"] = scale[(sadm >> 6) % 4] - patient["weight"] = np.fromfile(fid, UINT8, 1)[0] + patient["weight"] = read_from_file_or_buffer(fid, UINT8, 1)[0] if patient["weight"] == 0 or patient["weight"] == 255: patient["weight"] = None - patient["height"] = np.fromfile(fid, UINT8, 1)[0] + patient["height"] = read_from_file_or_buffer(fid, UINT8, 1)[0] if patient["height"] == 0 or patient["height"] == 255: patient["height"] = None # Gender / Handedness / Visual Impairment - ghi = 
np.fromfile(fid, UINT8, 1)[0] + ghi = read_from_file_or_buffer(fid, UINT8, 1)[0] patient["sex"] = gender[ghi % 4] patient["handedness"] = handedness[(ghi >> 2) % 4] patient["visual"] = scale[(ghi >> 4) % 4] @@ -1255,7 +1543,7 @@ def _read_gdf_header(fname, exclude, include=None): # Recording identification meas_id = {} meas_id["recording_id"] = _edf_str(fid.read(64)).strip() - vhsv = np.fromfile(fid, UINT8, 4) + vhsv = read_from_file_or_buffer(fid, UINT8, 4) loc = {} if vhsv[3] == 0: loc["vertpre"] = 10 * int(vhsv[0] >> 4) + int(vhsv[0] % 16) @@ -1266,12 +1554,16 @@ def _read_gdf_header(fname, exclude, include=None): loc["horzpre"] = 29 loc["size"] = 29 loc["version"] = 0 - loc["latitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc["longitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc["altitude"] = float(np.fromfile(fid, INT32, 1)[0]) / 100 + loc["latitude"] = ( + float(read_from_file_or_buffer(fid, UINT32, 1)[0]) / 3600000 + ) + loc["longitude"] = ( + float(read_from_file_or_buffer(fid, UINT32, 1)[0]) / 3600000 + ) + loc["altitude"] = float(read_from_file_or_buffer(fid, INT32, 1)[0]) / 100 meas_id["loc"] = loc - meas_date = np.fromfile(fid, UINT64, 1)[0] + meas_date = read_from_file_or_buffer(fid, UINT64, 1)[0] if meas_date != 0: meas_date = datetime(1, 1, 1, tzinfo=timezone.utc) + timedelta( meas_date * pow(2, -32) - 367 @@ -1279,7 +1571,7 @@ def _read_gdf_header(fname, exclude, include=None): else: meas_date = None - birthday = np.fromfile(fid, UINT64, 1).tolist()[0] + birthday = read_from_file_or_buffer(fid, UINT64, 1).tolist()[0] if birthday == 0: birthday = datetime(1, 1, 1, tzinfo=timezone.utc) else: @@ -1298,22 +1590,22 @@ def _read_gdf_header(fname, exclude, include=None): else: patient["age"] = None - header_nbytes = np.fromfile(fid, UINT16, 1)[0] * 256 + header_nbytes = read_from_file_or_buffer(fid, UINT16, 1)[0] * 256 fid.seek(6, 1) # 6 bytes reserved - meas_id["equipment"] = np.fromfile(fid, UINT8, 8) - meas_id["ip"] = 
np.fromfile(fid, UINT8, 6) - patient["headsize"] = np.fromfile(fid, UINT16, 3) + meas_id["equipment"] = read_from_file_or_buffer(fid, UINT8, 8) + meas_id["ip"] = read_from_file_or_buffer(fid, UINT8, 6) + patient["headsize"] = read_from_file_or_buffer(fid, UINT16, 3) patient["headsize"] = np.asarray(patient["headsize"], np.float32) patient["headsize"] = np.ma.masked_array( patient["headsize"], np.equal(patient["headsize"], 0), None ).filled() - ref = np.fromfile(fid, FLOAT32, 3) - gnd = np.fromfile(fid, FLOAT32, 3) - n_records = np.fromfile(fid, INT64, 1)[0] + ref = read_from_file_or_buffer(fid, FLOAT32, 3) + gnd = read_from_file_or_buffer(fid, FLOAT32, 3) + n_records = read_from_file_or_buffer(fid, INT64, 1)[0] # record length in seconds - record_length = np.fromfile(fid, UINT32, 2) + record_length = read_from_file_or_buffer(fid, UINT32, 2) if record_length[0] == 0: record_length[0] = 1.0 warn( @@ -1321,7 +1613,7 @@ def _read_gdf_header(fname, exclude, include=None): "Default record length set to 1." 
) - nchan = int(np.fromfile(fid, UINT16, 1)[0]) + nchan = int(read_from_file_or_buffer(fid, UINT16, 1)[0]) fid.seek(2, 1) # 2bytes reserved # Channels (variable header) @@ -1339,7 +1631,7 @@ def _read_gdf_header(fname, exclude, include=None): - Decimal factors codes: https://sourceforge.net/p/biosig/svn/HEAD/tree/trunk/biosig/doc/DecimalFactors.txt """ # noqa - units = np.fromfile(fid, UINT16, len(channels)).tolist() + units = read_from_file_or_buffer(fid, UINT16, len(channels)).tolist() unitcodes = np.array(units[:]) edf_info["units"] = list() for i, unit in enumerate(units): @@ -1363,32 +1655,36 @@ def _read_gdf_header(fname, exclude, include=None): edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] - physical_min = np.fromfile(fid, FLOAT64, len(channels)) - physical_max = np.fromfile(fid, FLOAT64, len(channels)) - digital_min = np.fromfile(fid, FLOAT64, len(channels)) - digital_max = np.fromfile(fid, FLOAT64, len(channels)) + physical_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + physical_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) fid.seek(68 * len(channels), 1) # obsolete - lowpass = np.fromfile(fid, FLOAT32, len(channels)) - highpass = np.fromfile(fid, FLOAT32, len(channels)) - notch = np.fromfile(fid, FLOAT32, len(channels)) + lowpass = read_from_file_or_buffer(fid, FLOAT32, len(channels)) + highpass = read_from_file_or_buffer(fid, FLOAT32, len(channels)) + notch = read_from_file_or_buffer(fid, FLOAT32, len(channels)) # number of samples per record - n_samps = np.fromfile(fid, INT32, len(channels)) + n_samps = read_from_file_or_buffer(fid, INT32, len(channels)) # data type - dtype = np.fromfile(fid, INT32, len(channels)) + dtype = read_from_file_or_buffer(fid, INT32, len(channels)) channel = {} - channel["xyz"] = [np.fromfile(fid, FLOAT32, 
3)[0] for ch in channels] + channel["xyz"] = [ + read_from_file_or_buffer(fid, FLOAT32, 3)[0] for ch in channels + ] if edf_info["number"] < 2.19: - impedance = np.fromfile(fid, UINT8, len(channels)).astype(float) + impedance = read_from_file_or_buffer(fid, UINT8, len(channels)).astype( + float + ) impedance[impedance == 255] = np.nan channel["impedance"] = pow(2, impedance / 8) fid.seek(19 * len(channels), 1) # reserved else: - tmp = np.fromfile(fid, FLOAT32, 5 * len(channels)) + tmp = read_from_file_or_buffer(fid, FLOAT32, 5 * len(channels)) tmp = tmp[::5] fZ = tmp[:] impedance = tmp[:] @@ -1446,22 +1742,24 @@ def _read_gdf_header(fname, exclude, include=None): etmode = np.fromstring(etmode, UINT8).tolist()[0] if edf_info["number"] < 1.94: - sr = np.fromfile(fid, UINT8, 3) + sr = read_from_file_or_buffer(fid, UINT8, 3) event_sr = sr[0] for i in range(1, len(sr)): event_sr = event_sr + sr[i] * 2 ** (i * 8) - n_events = np.fromfile(fid, UINT32, 1)[0] + n_events = read_from_file_or_buffer(fid, UINT32, 1)[0] else: - ne = np.fromfile(fid, UINT8, 3) + ne = read_from_file_or_buffer(fid, UINT8, 3) n_events = sum(int(ne[i]) << (i * 8) for i in range(len(ne))) - event_sr = np.fromfile(fid, FLOAT32, 1)[0] + event_sr = read_from_file_or_buffer(fid, FLOAT32, 1)[0] - pos = np.fromfile(fid, UINT32, n_events) - 1 # 1-based inds - typ = np.fromfile(fid, UINT16, n_events) + pos = ( + read_from_file_or_buffer(fid, UINT32, n_events) - 1 + ) # 1-based inds + typ = read_from_file_or_buffer(fid, UINT16, n_events) if etmode == 3: - chn = np.fromfile(fid, UINT16, n_events) - dur = np.fromfile(fid, UINT32, n_events) + chn = read_from_file_or_buffer(fid, UINT16, n_events) + dur = read_from_file_or_buffer(fid, UINT32, n_events) else: chn = np.zeros(n_events, dtype=np.uint32) dur = np.ones(n_events, dtype=np.uint32) @@ -1576,6 +1874,20 @@ def _find_tal_idx(ch_names): return tal_channel_idx +def _check_args(input_fname, preload, target_ext): + if not _file_like(input_fname): + input_fname = 
Path to the BDF file or the BDF file itself. If a file-like object is + provided, preload must be used.
Use func:`mne.find_events` to parse events encoded in such analog stim channels. """ - input_fname = os.path.abspath(input_fname) - ext = os.path.splitext(input_fname)[1][1:].lower() - if ext != "bdf": - raise NotImplementedError(f"Only BDF files are supported, got {ext}.") - return RawEDF( + _check_args(input_fname, preload, "bdf") + + return RawBDF( input_fname=input_fname, eog=eog, misc=misc, @@ -1862,8 +2180,13 @@ def read_raw_gdf( Parameters ---------- - input_fname : path-like - Path to the GDF file. + input_fname : path-like | file-like + Path to the GDF file or GDF file itself. If a file-like object is + provided, preload must be used. + + .. versionchanged:: 1.10 + Added support for file-like objects + eog : list or tuple Names of channels or list of indices that should be designated EOG channels. Values should correspond to the electrodes in the file. @@ -1905,10 +2228,8 @@ def read_raw_gdf( STIM channels by default. Use func:`mne.find_events` to parse events encoded in such analog stim channels. 
Added support for file-like objects.
"""Test that RawBDF is NOT able to read from file-like objects for non-BDF files.""" + with pytest.raises(Exception, match="Bad BDF file provided."): + with open(edf_txt_stim_channel_path, "rb") as blob: + read_raw_bdf(BytesIO(blob.read()), preload=True) + + +@testing.requires_testing_data +def test_bdf_read_from_file_like(): + """Test that RawBDF is able to read from file-like objects for BDF files."""
"""Test that RawGDF is NOT able to read from file-like objects for non-GDF files."""