Skip to content

Commit c6956e7

Browse files
authored
[MRG] Rank refactor part2 (#5870)
* start moving things to rank.py * more * more * more * more * fix * FIX: Minor fixes
1 parent d5ae1c8 commit c6956e7

File tree

13 files changed

+471
-489
lines changed

13 files changed

+471
-489
lines changed

mne/cov.py

Lines changed: 10 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414
import numpy as np
1515
from scipy import linalg, sparse
1616

17+
from .io.meas_info import _simplify_info
1718
from .io.write import start_file, end_file
1819
from .io.proj import (make_projector, _proj_equal, activate_proj,
19-
_needs_eeg_average_ref_proj, _check_projs,
20+
_check_projs, _needs_eeg_average_ref_proj,
2021
_has_eeg_average_ref_proj)
2122
from .io import fiff_open
2223
from .io.pick import (pick_types, pick_channels_cov, pick_channels, pick_info,
2324
_picks_by_type, _pick_data_channels,
2425
_DATA_CH_TYPES_SPLIT)
2526

2627
from .io.constants import FIFF
27-
from .io.meas_info import read_bad_channels, _simplify_info, create_info
28+
from .io.meas_info import read_bad_channels, create_info
2829
from .io.proj import _read_proj, _write_proj
2930
from .io.tag import find_tag
3031
from .io.tree import dir_tree_find
@@ -33,10 +34,13 @@
3334
from .defaults import _handle_default
3435
from .epochs import Epochs
3536
from .event import make_fixed_length_events
36-
from .utils import (check_fname, logger, verbose, estimate_rank,
37-
_compute_row_norms, check_version, _time_mask, warn,
38-
copy_function_doc_to_method_doc, _pl)
37+
from .utils import (check_fname, logger, verbose,
38+
check_version, _time_mask, warn,
39+
copy_function_doc_to_method_doc, _pl,
40+
_undo_scaling_cov, _undo_scaling_array,
41+
_apply_scaling_array)
3942
from . import viz
43+
from .rank import _estimate_rank_meeg_cov
4044

4145
from .fixes import BaseEstimator, EmpiricalCovariance, _logdet
4246

@@ -629,7 +633,7 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
629633
630634
['shrunk', 'diagonal_fixed', 'empirical', 'factor_analysis']
631635
632-
``'factor_analysis'`` is removed when `rank` is not 'full'.
636+
``'factor_analysis'`` is removed when ``rank`` is not 'full'.
633637
The ``'auto'`` mode is not recommended if there are many
634638
segments of data, since computation can take a long time.
635639
@@ -1990,175 +1994,3 @@ def _write_cov(fid, cov):
19901994

19911995
# Done!
19921996
end_block(fid, FIFF.FIFFB_MNE_COV)
1993-
1994-
1995-
def _apply_scaling_array(data, picks_list, scalings):
1996-
"""Scale data type-dependently for estimation."""
1997-
scalings = _check_scaling_inputs(data, picks_list, scalings)
1998-
if isinstance(scalings, dict):
1999-
picks_dict = dict(picks_list)
2000-
scalings = [(picks_dict[k], v) for k, v in scalings.items()
2001-
if k in picks_dict]
2002-
for idx, scaling in scalings:
2003-
data[idx, :] *= scaling # F - order
2004-
else:
2005-
data *= scalings[:, np.newaxis] # F - order
2006-
2007-
2008-
def _invert_scalings(scalings):
2009-
if isinstance(scalings, dict):
2010-
scalings = dict((k, 1. / v) for k, v in scalings.items())
2011-
elif isinstance(scalings, np.ndarray):
2012-
scalings = 1. / scalings
2013-
return scalings
2014-
2015-
2016-
def _undo_scaling_array(data, picks_list, scalings):
2017-
scalings = _invert_scalings(_check_scaling_inputs(data, picks_list,
2018-
scalings))
2019-
return _apply_scaling_array(data, picks_list, scalings)
2020-
2021-
2022-
def _apply_scaling_cov(data, picks_list, scalings):
2023-
"""Scale resulting data after estimation."""
2024-
scalings = _check_scaling_inputs(data, picks_list, scalings)
2025-
scales = None
2026-
if isinstance(scalings, dict):
2027-
n_channels = len(data)
2028-
covinds = list(zip(*picks_list))[1]
2029-
assert len(data) == sum(len(k) for k in covinds)
2030-
assert list(sorted(np.concatenate(covinds))) == list(range(len(data)))
2031-
scales = np.zeros(n_channels)
2032-
for ch_t, idx in picks_list:
2033-
scales[idx] = scalings[ch_t]
2034-
elif isinstance(scalings, np.ndarray):
2035-
if len(scalings) != len(data):
2036-
raise ValueError('Scaling factors and data are of incompatible '
2037-
'shape')
2038-
scales = scalings
2039-
elif scalings is None:
2040-
pass
2041-
else:
2042-
raise RuntimeError('Arff...')
2043-
if scales is not None:
2044-
assert np.sum(scales == 0.) == 0
2045-
data *= (scales[None, :] * scales[:, None])
2046-
2047-
2048-
def _undo_scaling_cov(data, picks_list, scalings):
2049-
scalings = _invert_scalings(_check_scaling_inputs(data, picks_list,
2050-
scalings))
2051-
return _apply_scaling_cov(data, picks_list, scalings)
2052-
2053-
2054-
def _check_scaling_inputs(data, picks_list, scalings):
2055-
"""Aux function."""
2056-
rescale_dict_ = dict(mag=1e15, grad=1e13, eeg=1e6)
2057-
2058-
scalings_ = None
2059-
if isinstance(scalings, str) and scalings == 'norm':
2060-
scalings_ = 1. / _compute_row_norms(data)
2061-
elif isinstance(scalings, dict):
2062-
rescale_dict_.update(scalings)
2063-
scalings_ = rescale_dict_
2064-
elif isinstance(scalings, np.ndarray):
2065-
scalings_ = scalings
2066-
elif scalings is None:
2067-
pass
2068-
else:
2069-
raise NotImplementedError("No way! That's not a rescaling "
2070-
'option: %s' % scalings)
2071-
return scalings_
2072-
2073-
2074-
def _estimate_rank_meeg_signals(data, info, scalings, tol='auto',
2075-
return_singular=False):
2076-
"""Estimate rank for M/EEG data.
2077-
2078-
Parameters
2079-
----------
2080-
data : np.ndarray of float, shape(n_channels, n_samples)
2081-
The M/EEG signals.
2082-
info : Info
2083-
The measurement info.
2084-
scalings : dict | 'norm' | np.ndarray | None
2085-
The rescaling method to be applied. If dict, it will override the
2086-
following default dict:
2087-
2088-
dict(mag=1e15, grad=1e13, eeg=1e6)
2089-
2090-
If 'norm' data will be scaled by channel-wise norms. If array,
2091-
pre-specified norms will be used. If None, no scaling will be applied.
2092-
tol : float | str
2093-
Tolerance. See ``estimate_rank``.
2094-
return_singular : bool
2095-
If True, also return the singular values that were used
2096-
to determine the rank.
2097-
2098-
Returns
2099-
-------
2100-
rank : int
2101-
Estimated rank of the data.
2102-
s : array
2103-
If return_singular is True, the singular values that were
2104-
thresholded to determine the rank are also returned.
2105-
"""
2106-
picks_list = _picks_by_type(info)
2107-
_apply_scaling_array(data, picks_list, scalings)
2108-
if data.shape[1] < data.shape[0]:
2109-
ValueError("You've got fewer samples than channels, your "
2110-
"rank estimate might be inaccurate.")
2111-
out = estimate_rank(data, tol=tol, norm=False,
2112-
return_singular=return_singular)
2113-
rank = out[0] if isinstance(out, tuple) else out
2114-
ch_type = ' + '.join(list(zip(*picks_list))[0])
2115-
logger.info('estimated rank (%s): %d' % (ch_type, rank))
2116-
_undo_scaling_array(data, picks_list, scalings)
2117-
return out
2118-
2119-
2120-
def _estimate_rank_meeg_cov(data, info, scalings, tol='auto',
2121-
return_singular=False):
2122-
"""Estimate rank of M/EEG covariance data, given the covariance.
2123-
2124-
Parameters
2125-
----------
2126-
data : np.ndarray of float, shape (n_channels, n_channels)
2127-
The M/EEG covariance.
2128-
info : Info
2129-
The measurement info.
2130-
scalings : dict | 'norm' | np.ndarray | None
2131-
The rescaling method to be applied. If dict, it will override the
2132-
following default dict:
2133-
2134-
dict(mag=1e12, grad=1e11, eeg=1e5)
2135-
2136-
If 'norm' data will be scaled by channel-wise norms. If array,
2137-
pre-specified norms will be used. If None, no scaling will be applied.
2138-
tol : float | str
2139-
Tolerance. See ``estimate_rank``.
2140-
return_singular : bool
2141-
If True, also return the singular values that were used
2142-
to determine the rank.
2143-
2144-
Returns
2145-
-------
2146-
rank : int
2147-
Estimated rank of the data.
2148-
s : array
2149-
If return_singular is True, the singular values that were
2150-
thresholded to determine the rank are also returned.
2151-
"""
2152-
picks_list = _picks_by_type(info)
2153-
scalings = _handle_default('scalings_cov_rank', scalings)
2154-
_apply_scaling_cov(data, picks_list, scalings)
2155-
if data.shape[1] < data.shape[0]:
2156-
ValueError("You've got fewer samples than channels, your "
2157-
"rank estimate might be inaccurate.")
2158-
out = estimate_rank(data, tol=tol, norm=False,
2159-
return_singular=return_singular)
2160-
rank = out[0] if isinstance(out, tuple) else out
2161-
ch_type = ' + '.join(list(zip(*picks_list))[0])
2162-
logger.info('estimated rank (%s): %d' % (ch_type, rank))
2163-
_undo_scaling_cov(data, picks_list, scalings)
2164-
return out

mne/io/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1959,7 +1959,7 @@ def estimate_rank(self, tstart=0.0, tstop=30.0, tol=1e-4,
19591959
19601960
Bad channels will be excluded from calculations.
19611961
"""
1962-
from ..cov import _estimate_rank_meeg_signals
1962+
from ..rank import _estimate_rank_meeg_signals
19631963

19641964
start = max(0, self.time_as_index(tstart)[0])
19651965
if tstop is None:

mne/io/fiff/tests/test_raw_fiff.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from copy import deepcopy
88
from functools import partial
9-
import itertools as itt
109
import os.path as op
1110
import pickle
1211
import sys
@@ -26,8 +25,6 @@
2625
pick_info)
2726
from mne.utils import (_TempDir, requires_pandas, object_diff,
2827
requires_mne, run_subprocess, run_tests_if_main)
29-
from mne.io.proc_history import _get_rank_sss
30-
from mne.io.pick import _picks_by_type
3128
from mne.annotations import Annotations
3229

3330
testing_path = testing.data_path(download=False)
@@ -202,37 +199,6 @@ def test_copy_append():
202199
assert_equal(data.shape[1], 2 * raw._data.shape[1])
203200

204201

205-
@pytest.mark.slowtest
206-
@testing.requires_testing_data
207-
def test_rank_estimation():
208-
"""Test raw rank estimation."""
209-
iter_tests = itt.product(
210-
[fif_fname, hp_fif_fname], # sss
211-
['norm', dict(mag=1e11, grad=1e9, eeg=1e5)]
212-
)
213-
for fname, scalings in iter_tests:
214-
raw = read_raw_fif(fname).crop(0, 4.).load_data()
215-
(_, picks_meg), (_, picks_eeg) = _picks_by_type(raw.info,
216-
meg_combined=True)
217-
n_meg = len(picks_meg)
218-
n_eeg = len(picks_eeg)
219-
220-
if len(raw.info['proc_history']) == 0:
221-
expected_rank = n_meg + n_eeg
222-
else:
223-
expected_rank = _get_rank_sss(raw.info) + n_eeg
224-
assert_array_equal(raw.estimate_rank(scalings=scalings), expected_rank)
225-
assert_array_equal(raw.estimate_rank(picks=picks_eeg,
226-
scalings=scalings), n_eeg)
227-
if 'sss' in fname:
228-
raw.add_proj(compute_proj_raw(raw))
229-
raw.apply_proj()
230-
n_proj = len(raw.info['projs'])
231-
assert_array_equal(raw.estimate_rank(tstart=0, tstop=3.,
232-
scalings=scalings),
233-
expected_rank - (0 if 'sss' in fname else n_proj))
234-
235-
236202
@testing.requires_testing_data
237203
def test_output_formats():
238204
"""Test saving and loading raw data using multiple formats."""

mne/io/proc_history.py

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
write_float_sparse, write_id)
1616
from .tag import find_tag
1717
from .constants import FIFF
18-
from ..utils import warn, logger
18+
from ..utils import warn
1919

2020
_proc_keys = ['parent_file_id', 'block_id', 'parent_block_id',
2121
'date', 'experimenter', 'creator']
@@ -295,52 +295,3 @@ def _write_maxfilter_record(fid, record):
295295
if key in sss_cal:
296296
writer(fid, id_, sss_cal[key])
297297
end_block(fid, FIFF.FIFFB_SSS_CAL)
298-
299-
300-
def _get_sss_rank(sss):
301-
"""Get SSS rank."""
302-
inside = sss['sss_info']['in_order']
303-
nfree = (inside + 1) ** 2 - 1
304-
nfree -= (len(sss['sss_info']['components'][:nfree]) -
305-
sss['sss_info']['components'][:nfree].sum())
306-
return nfree
307-
308-
309-
def _get_rank_sss(inst):
310-
"""Look up rank from SSS data.
311-
312-
.. note::
313-
Throws an error if SSS has not been applied.
314-
315-
Parameters
316-
----------
317-
inst : instance of Raw, Epochs or Evoked, or Info
318-
Any MNE object with an .info attribute
319-
320-
Returns
321-
-------
322-
rank : int
323-
The numerical rank as predicted by the number of SSS
324-
components.
325-
"""
326-
from .meas_info import Info
327-
info = inst if isinstance(inst, Info) else inst.info
328-
del inst
329-
330-
max_infos = list()
331-
for proc_info in info.get('proc_history', list()):
332-
max_info = proc_info.get('max_info')
333-
if max_info is not None:
334-
if len(max_info) > 0:
335-
max_infos.append(max_info)
336-
elif len(max_info) > 1:
337-
logger.info('found multiple SSS records. Using the first.')
338-
elif len(max_info) == 0:
339-
raise ValueError(
340-
'Did not find any SSS record. You should use data-based '
341-
'rank estimate instead')
342-
if len(max_infos) > 0:
343-
max_info = max_infos[0]
344-
else:
345-
raise ValueError('There is no `max_info` here. Sorry.')
346-
return _get_sss_rank(max_info)

mne/io/tests/test_proc_history.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import numpy as np
88
from numpy.testing import assert_array_equal
99

10-
from mne.io import read_info, read_raw_fif
10+
from mne.io import read_info
1111
from mne.io.constants import FIFF
12-
from mne.io.proc_history import _get_rank_sss
1312

1413
base_dir = op.join(op.dirname(__file__), 'data')
1514
raw_fname = op.join(base_dir, 'test_chpi_raw_sss.fif')
@@ -37,12 +36,3 @@ def test_maxfilter_io():
3736
assert mf['sss_cal']['cal_chans'].shape == (306, 2)
3837
vv_coils = [v for k, v in FIFF.items() if 'FIFFV_COIL_VV' in k]
3938
assert all(k in vv_coils for k in set(mf['sss_cal']['cal_chans'][:, 1]))
40-
41-
42-
def test_maxfilter_get_rank():
43-
"""Test maxfilter rank lookup."""
44-
raw = read_raw_fif(raw_fname)
45-
mf = raw.info['proc_history'][0]['max_info']
46-
rank1 = mf['sss_info']['nfree']
47-
rank2 = _get_rank_sss(raw)
48-
assert rank1 == rank2

mne/preprocessing/tests/test_maxwell.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from mne.cov import _estimate_rank_meeg_cov
1919
from mne.datasets import testing
2020
from mne.forward import use_coil_def
21-
from mne.io import (read_raw_fif, proc_history, read_info, read_raw_bti,
22-
read_raw_kit, BaseRaw)
21+
from mne.io import read_raw_fif, read_info, read_raw_bti, read_raw_kit, BaseRaw
2322
from mne.preprocessing.maxwell import (
2423
maxwell_filter, _get_n_moments, _sss_basis_basic, _sh_complex_to_real,
2524
_sh_real_to_complex, _sh_negate, _bases_complex_to_real, _trans_sss_basis,
2625
_bases_real_to_complex, _prep_mf_coils)
26+
from mne.rank import _get_sss_rank
2727
from mne.tests.common import assert_meg_snr
2828
from mne.utils import (_TempDir, run_tests_if_main, catch_logging,
2929
requires_version, object_diff, buggy_mkl_svd)
@@ -439,8 +439,7 @@ def test_basic():
439439

440440
# Check against SSS functions from proc_history
441441
sss_info = raw_sss.info['proc_history'][0]['max_info']
442-
assert_equal(_get_n_moments(int_order),
443-
proc_history._get_sss_rank(sss_info))
442+
assert_equal(_get_n_moments(int_order), _get_sss_rank(sss_info))
444443

445444
# Degenerate cases
446445
pytest.raises(ValueError, maxwell_filter, raw, coord_frame='foo')

0 commit comments

Comments
 (0)