Skip to content

Commit 784d381

Browse files
larsonersnwnde
authored andcommitted
ENH: Add Forward.save and hdf5 support (mne-tools#12036)
1 parent 26716d5 commit 784d381

File tree

6 files changed

+69
-8
lines changed

6 files changed

+69
-8
lines changed

doc/changes/devel.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Enhancements
3333
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
3434
- Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") <mne.io.read_raw_eeglab>` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_)
3535
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)
36+
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_)
3637
- Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_)
3738
- Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_)
3839

mne/_fiff/meas_info.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
_check_on_missing,
8181
fill_doc,
8282
_check_fname,
83+
check_fname,
8384
repr_html,
8485
)
8586
from ._digitization import (
@@ -2006,6 +2007,8 @@ def read_info(fname, verbose=None):
20062007
-------
20072008
%(info_not_none)s
20082009
"""
2010+
check_fname(fname, "Info", (".fif", ".fif.gz"))
2011+
fname = _check_fname(fname, must_exist=True, overwrite="read")
20092012
f, tree, _ = fiff_open(fname)
20102013
with f as fid:
20112014
info = read_meas_info(fid, tree)[0]

mne/_fiff/tests/test_meas_info.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,14 @@ def test_read_write_info(tmp_path):
345345
write_info(fname, info)
346346

347347

348+
@testing.requires_testing_data
349+
def test_dir_warning():
350+
"""Test that trying to read a bad filename emits a warning before an error."""
351+
with pytest.raises(OSError, match="directory"):
352+
with pytest.warns(RuntimeWarning, match="foo"):
353+
read_info(ctf_fname)
354+
355+
348356
def test_io_dig_points(tmp_path):
349357
"""Test Writing for dig files."""
350358
dest = tmp_path / "test.txt"

mne/forward/forward.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
_stamp_to_dt,
8282
_on_missing,
8383
repr_html,
84+
_import_h5io_funcs,
8485
)
8586
from ..label import Label
8687

@@ -165,6 +166,18 @@ def copy(self):
165166
"""Copy the Forward instance."""
166167
return Forward(deepcopy(self))
167168

169+
@verbose
170+
def save(self, fname, *, overwrite=False, verbose=None):
171+
"""Save the forward solution.
172+
173+
Parameters
174+
----------
175+
%(fname_fwd)s
176+
%(overwrite)s
177+
%(verbose)s
178+
"""
179+
write_forward_solution(fname, self, overwrite=overwrite)
180+
168181
def _get_src_type_and_ori_for_repr(self):
169182
src_types = np.array([src["type"] for src in self["src"]])
170183

@@ -520,7 +533,8 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos
520533
Parameters
521534
----------
522535
fname : path-like
523-
The file name, which should end with ``-fwd.fif`` or ``-fwd.fif.gz``.
536+
The file name, which should end with ``-fwd.fif``, ``-fwd.fif.gz``,
537+
``_fwd.fif``, ``_fwd.fif.gz``, ``-fwd.h5``, or ``_fwd.h5``.
524538
include : list, optional
525539
List of names of channels to include. If empty all channels
526540
are included.
@@ -554,11 +568,15 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbos
554568
forward solution with :func:`read_forward_solution`.
555569
"""
556570
check_fname(
557-
fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz")
571+
fname,
572+
"forward",
573+
("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"),
558574
)
559575
fname = _check_fname(fname=fname, must_exist=True, overwrite="read")
560576
# Open the file, create directory
561577
logger.info("Reading forward solution from %s..." % fname)
578+
if fname.suffix == ".h5":
579+
return _read_forward_hdf5(fname)
562580
f, tree, _ = fiff_open(fname)
563581
with f as fid:
564582
# Find all forward solutions
@@ -861,9 +879,7 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
861879
862880
Parameters
863881
----------
864-
fname : path-like
865-
File name to save the forward solution to. It should end with
866-
``-fwd.fif`` or ``-fwd.fif.gz``.
882+
%(fname_fwd)s
867883
fwd : Forward
868884
Forward solution.
869885
%(overwrite)s
@@ -889,13 +905,28 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
889905
forward solution with :func:`read_forward_solution`.
890906
"""
891907
check_fname(
892-
fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz")
908+
fname,
909+
"forward",
910+
("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz", "-fwd.h5", "_fwd.h5"),
893911
)
894912

895913
# check for file existence and expand `~` if present
896914
fname = _check_fname(fname, overwrite)
897-
with start_and_end_file(fname) as fid:
898-
_write_forward_solution(fid, fwd)
915+
if fname.suffix == ".h5":
916+
_write_forward_hdf5(fname, fwd)
917+
else:
918+
with start_and_end_file(fname) as fid:
919+
_write_forward_solution(fid, fwd)
920+
921+
922+
def _write_forward_hdf5(fname, fwd):
923+
_, write_hdf5 = _import_h5io_funcs()
924+
write_hdf5(fname, dict(fwd=fwd), overwrite=True)
925+
926+
927+
def _read_forward_hdf5(fname):
928+
read_hdf5, _ = _import_h5io_funcs()
929+
return Forward(read_hdf5(fname)["fwd"])
899930

900931

901932
def _write_forward_solution(fid, fwd):

mne/forward/tests/test_forward.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,15 @@ def test_io_forward(tmp_path):
197197
fwd_read = read_forward_solution(fname_temp)
198198
assert_forward_allclose(fwd, fwd_read)
199199

200+
h5py = pytest.importorskip("h5py")
201+
pytest.importorskip("h5io")
202+
fname_h5 = fname_temp.with_suffix(".h5")
203+
fwd.save(fname_h5)
204+
with h5py.File(fname_h5, "r"):
205+
pass # just checks for hdf5-ness
206+
fwd_read = read_forward_solution(fname_h5)
207+
assert_forward_allclose(fwd, fwd_read)
208+
200209

201210
@testing.requires_testing_data
202211
def test_apply_forward():

mne/utils/docs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
16941694
Name of the output file.
16951695
"""
16961696

1697+
docdict[
1698+
"fname_fwd"
1699+
] = """
1700+
fname : path-like
1701+
File name to save the forward solution to. It should end with
1702+
``-fwd.fif`` or ``-fwd.fif.gz`` to save to FIF, or ``-fwd.h5`` to save to
1703+
HDF5.
1704+
"""
1705+
16971706
docdict[
16981707
"fnirs"
16991708
] = """

0 commit comments

Comments
 (0)