|
41 | 41 | pick_info,
|
42 | 42 | pick_types,
|
43 | 43 | )
|
44 |
| -from .._fiff.proj import setup_proj |
| 44 | +from .._fiff.proj import _has_eeg_average_ref_proj, setup_proj |
45 | 45 | from .._fiff.reference import add_reference_channels, set_eeg_reference
|
46 | 46 | from .._fiff.tag import _rename_list
|
47 | 47 | from ..bem import _check_origin
|
@@ -960,6 +960,162 @@ def interpolate_bads(
|
960 | 960 |
|
961 | 961 | return self
|
962 | 962 |
|
| 963 | + def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): |
| 964 | + """Interpolate EEG data onto a new montage. |
| 965 | +
|
| 966 | + .. warning:: |
| 967 | + Be careful, only EEG channels are interpolated. Other channel types are |
| 968 | + not interpolated. |
| 969 | +
|
| 970 | + Parameters |
| 971 | + ---------- |
| 972 | + sensors : DigMontage |
| 973 | + The target montage containing channel positions to interpolate onto. |
| 974 | + origin : array-like, shape (3,) | str |
| 975 | + Origin of the sphere in the head coordinate frame and in meters. |
| 976 | + Can be ``'auto'`` (default), which means a head-digitization-based |
| 977 | + origin fit. |
| 978 | + method : str |
| 979 | + Method to use for EEG channels. |
| 980 | + Supported methods are 'spline' (default) and 'MNE'. |
| 981 | + reg : float |
| 982 | + The regularization parameter for the interpolation method |
| 983 | + (only used when the method is 'spline'). |
| 984 | +
|
| 985 | + Returns |
| 986 | + ------- |
| 987 | + inst : instance of Raw, Epochs, or Evoked |
| 988 | + The instance with updated channel locations and data. |
| 989 | +
|
| 990 | + Notes |
| 991 | + ----- |
| 992 | + This method is useful for standardizing EEG layouts across datasets. |
| 993 | + However, some attributes may be lost after interpolation. |
| 994 | +
|
| 995 | + .. versionadded:: 1.10.0 |
| 996 | + """ |
| 997 | + from ..epochs import BaseEpochs, EpochsArray |
| 998 | + from ..evoked import Evoked, EvokedArray |
| 999 | + from ..forward._field_interpolation import _map_meg_or_eeg_channels |
| 1000 | + from ..io import RawArray |
| 1001 | + from ..io.base import BaseRaw |
| 1002 | + from .interpolation import _make_interpolation_matrix |
| 1003 | + from .montage import DigMontage |
| 1004 | + |
| 1005 | + # Check that the method option is valid. |
| 1006 | + _check_option("method", method, ["spline", "MNE"]) |
| 1007 | + _validate_type(sensors, DigMontage, "sensors") |
| 1008 | + |
| 1009 | + # Get target positions from the montage |
| 1010 | + ch_pos = sensors.get_positions().get("ch_pos", {}) |
| 1011 | + target_ch_names = list(ch_pos.keys()) |
| 1012 | + if not target_ch_names: |
| 1013 | + raise ValueError( |
| 1014 | + "The provided sensors configuration has no channel positions." |
| 1015 | + ) |
| 1016 | + |
| 1017 | + # Get original channel order |
| 1018 | + orig_names = self.info["ch_names"] |
| 1019 | + |
| 1020 | + # Identify EEG channel |
| 1021 | + picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads") |
| 1022 | + if len(picks_good_eeg) == 0: |
| 1023 | + raise ValueError("No good EEG channels available for interpolation.") |
| 1024 | + # Also get the full list of EEG channel indices (including bad channels) |
| 1025 | + picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[]) |
| 1026 | + eeg_names_orig = [orig_names[i] for i in picks_remove_eeg] |
| 1027 | + |
| 1028 | + # Identify non-EEG channels in original order |
| 1029 | + non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig] |
| 1030 | + |
| 1031 | + # Create destination info for new EEG channels |
| 1032 | + sfreq = self.info["sfreq"] |
| 1033 | + info_interp = create_info( |
| 1034 | + ch_names=target_ch_names, |
| 1035 | + sfreq=sfreq, |
| 1036 | + ch_types=["eeg"] * len(target_ch_names), |
| 1037 | + ) |
| 1038 | + info_interp.set_montage(sensors) |
| 1039 | + info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names] |
| 1040 | + # Do not assign "projs" directly. |
| 1041 | + |
| 1042 | + # Compute the interpolation mapping |
| 1043 | + if method == "spline": |
| 1044 | + origin_val = _check_origin(origin, self.info) |
| 1045 | + pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val |
| 1046 | + pos_to = np.stack(list(ch_pos.values()), axis=0) |
| 1047 | + |
| 1048 | + def _check_pos_sphere(pos): |
| 1049 | + d = np.linalg.norm(pos, axis=-1) |
| 1050 | + d_norm = np.mean(d / np.mean(d)) |
| 1051 | + if np.abs(1.0 - d_norm) > 0.1: |
| 1052 | + warn("Your spherical fit is poor; interpolation may be inaccurate.") |
| 1053 | + |
| 1054 | + _check_pos_sphere(pos_from) |
| 1055 | + _check_pos_sphere(pos_to) |
| 1056 | + mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) |
| 1057 | + |
| 1058 | + else: |
| 1059 | + assert method == "MNE" |
| 1060 | + info_eeg = pick_info(self.info, picks_good_eeg) |
| 1061 | + # If the original info has an average EEG reference projector but |
| 1062 | + # the destination info does not, |
| 1063 | + # update info_interp via a temporary RawArray. |
| 1064 | + if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj( |
| 1065 | + info_interp |
| 1066 | + ): |
| 1067 | + # Create dummy data: shape (n_channels, 1) |
| 1068 | + temp_data = np.zeros((len(info_interp["ch_names"]), 1)) |
| 1069 | + temp_raw = RawArray(temp_data, info_interp, first_samp=0) |
| 1070 | + # Using the public API, add an average reference projector. |
| 1071 | + temp_raw.set_eeg_reference( |
| 1072 | + ref_channels="average", projection=True, verbose=False |
| 1073 | + ) |
| 1074 | + # Extract the updated info. |
| 1075 | + info_interp = temp_raw.info |
| 1076 | + mapping = _map_meg_or_eeg_channels( |
| 1077 | + info_eeg, info_interp, mode="accurate", origin=origin |
| 1078 | + ) |
| 1079 | + |
| 1080 | + # Interpolate EEG data |
| 1081 | + data_good = self.get_data(picks=picks_good_eeg) |
| 1082 | + data_interp = mapping @ data_good |
| 1083 | + |
| 1084 | + # Create a new instance for the interpolated EEG channels |
| 1085 | + # TODO: Creating a new instance leads to a loss of information. |
| 1086 | + # We should consider updating the existing instance in the future |
| 1087 | + # by 1) drop channels, 2) add channels, 3) re-order channels. |
| 1088 | + if isinstance(self, BaseRaw): |
| 1089 | + inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) |
| 1090 | + elif isinstance(self, BaseEpochs): |
| 1091 | + inst_interp = EpochsArray(data_interp, info_interp) |
| 1092 | + else: |
| 1093 | + assert isinstance(self, Evoked) |
| 1094 | + inst_interp = EvokedArray(data_interp, info_interp) |
| 1095 | + |
| 1096 | + # Merge only if non-EEG channels exist |
| 1097 | + if not non_eeg_names_ordered: |
| 1098 | + return inst_interp |
| 1099 | + |
| 1100 | + inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data() |
| 1101 | + inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True) |
| 1102 | + |
| 1103 | + # Reorder channels |
| 1104 | + # Insert the entire new EEG block at the position of the first EEG channel. |
| 1105 | + orig_names_arr = np.array(orig_names) |
| 1106 | + mask_eeg = np.isin(orig_names_arr, eeg_names_orig) |
| 1107 | + if mask_eeg.any(): |
| 1108 | + first_eeg_index = np.where(mask_eeg)[0][0] |
| 1109 | + pre = orig_names_arr[:first_eeg_index] |
| 1110 | + new_eeg = np.array(info_interp["ch_names"]) |
| 1111 | + post = orig_names_arr[first_eeg_index:] |
| 1112 | + post = post[~np.isin(orig_names_arr[first_eeg_index:], eeg_names_orig)] |
| 1113 | + new_order = np.concatenate((pre, new_eeg, post)).tolist() |
| 1114 | + else: |
| 1115 | + new_order = orig_names |
| 1116 | + inst_out.reorder_channels(new_order) |
| 1117 | + return inst_out |
| 1118 | + |
963 | 1119 |
|
964 | 1120 | @verbose
|
965 | 1121 | def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None):
|
|
0 commit comments