9
9
import numpy as np
10
10
import pandas as pd
11
11
import spikeinterface as si
12
- from element_interface .utils import dict_to_uuid , find_full_path , find_root_directory
12
+ from element_interface .utils import (dict_to_uuid , find_full_path ,
13
+ find_root_directory )
13
14
from spikeinterface import exporters , postprocessing , qualitymetrics , sorters
14
15
15
16
from . import ephys_report , probe
@@ -1022,18 +1023,19 @@ def make(self, key):
1022
1023
extremum_channel_inds = unit_peak_channel_map
1023
1024
)
1024
1025
1025
- # Get electrode info !#TODO: need to be modified
1026
+ # Get electrode & channel info
1026
1027
electrode_config_key = (
1027
1028
EphysRecording * probe .ElectrodeConfig & key
1028
1029
).fetch1 ("KEY" )
1029
1030
1030
1031
electrode_query = (
1031
1032
probe .ProbeType .Electrode * probe .ElectrodeConfig .Electrode
1032
1033
& electrode_config_key
1033
- )
1034
+ ) * (dj .U ("electrode" , "channel_idx" ) & EphysRecording .Channel )
1035
+
1034
1036
channel2electrode_map = dict (
1035
- zip (* electrode_query .fetch ("channel " , "electrode" ))
1036
- )
1037
+ zip (* electrode_query .fetch ("channel_idx " , "electrode" ))
1038
+ ) # {channel: electrode}
1037
1039
1038
1040
# Get unit id to quality label mapping
1039
1041
cluster_quality_label_map = {}
@@ -1051,24 +1053,24 @@ def make(self, key):
1051
1053
pass
1052
1054
1053
1055
# Get channel to electrode mapping
1054
- channel2depth_map = dict (zip (* electrode_query .fetch ("channel " , "y_coord" )))
1056
+ channel2depth_map = dict (zip (* electrode_query .fetch ("channel_idx " , "y_coord" ))) # {channel: depth}
1055
1057
1056
1058
peak_electrode_ind = np .array (
1057
1059
[
1058
1060
channel2electrode_map [unit_peak_channel_map [unit_id ]]
1059
1061
for unit_id in si_sorting .unit_ids
1060
1062
]
1061
- )
1063
+ ) # get the electrode where peak unit activity is recorded
1062
1064
1063
1065
# Get channel to depth mapping
1064
- electrode_depth_ind = np .array (
1066
+ channel_depth_ind = np .array (
1065
1067
[
1066
1068
channel2depth_map [unit_peak_channel_map [unit_id ]]
1067
1069
for unit_id in si_sorting .unit_ids
1068
1070
]
1069
1071
)
1070
1072
spikes ["electrode" ] = peak_electrode_ind [spikes ["unit_index" ]]
1071
- spikes ["depth" ] = electrode_depth_ind [spikes ["unit_index" ]]
1073
+ spikes ["depth" ] = channel_depth_ind [spikes ["unit_index" ]]
1072
1074
1073
1075
units = []
1074
1076
@@ -1233,11 +1235,11 @@ def make(self, key):
1233
1235
we : si .WaveformExtractor = si .load_waveforms (
1234
1236
output_dir / "waveform" , with_recording = False
1235
1237
)
1236
- unit_id_to_peak_channel_indices : dict [
1237
- int , np . ndarray
1238
- ] = si . ChannelSparsity . from_best_channels (
1239
- we , 1 , peak_sign = "neg"
1240
- ). unit_id_to_channel_indices # {unit: peak_channel_index}
1238
+ unit_id_to_peak_channel_indices : dict [int , np . ndarray ] = (
1239
+ si . ChannelSparsity . from_best_channels (
1240
+ we , 1 , peak_sign = "neg"
1241
+ ). unit_id_to_channel_indices
1242
+ ) # {unit: peak_channel_index}
1241
1243
1242
1244
units = (CuratedClustering .Unit & key ).fetch ("KEY" , order_by = "unit" )
1243
1245
0 commit comments