8
8
import datajoint as dj
9
9
import numpy as np
10
10
import pandas as pd
11
- import spikeinterface as si
12
11
from element_interface .utils import dict_to_uuid , find_full_path , find_root_directory
13
- from spikeinterface import exporters , postprocessing , qualitymetrics , sorters
14
12
15
13
from . import ephys_report , probe
16
14
from .readers import kilosort , openephys , spikeglx
17
15
18
- log = dj .logger
16
+ logger = dj .logger
19
17
20
18
schema = dj .schema ()
21
19
@@ -824,7 +822,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
824
822
825
823
if mkdir :
826
824
output_dir .mkdir (parents = True , exist_ok = True )
827
- log .info (f"{ output_dir } created!" )
825
+ logger .info (f"{ output_dir } created!" )
828
826
829
827
return output_dir .relative_to (processed_dir ) if relative else output_dir
830
828
@@ -1030,16 +1028,17 @@ def make(self, key):
1030
1028
1031
1029
# Get channel and electrode-site mapping
1032
1030
electrode_query = (EphysRecording .Channel & key ).proj (..., "-channel_name" )
1033
- channel2electrode_map = electrode_query .fetch (as_dict = True )
1034
1031
channel2electrode_map : dict [int , dict ] = {
1035
- chn .pop ("channel_idx" ): chn for chn in channel2electrode_map
1032
+ chn .pop ("channel_idx" ): chn for chn in electrode_query . fetch ( as_dict = True )
1036
1033
}
1037
1034
1038
1035
# Get sorter method and create output directory.
1039
1036
sorter_name = clustering_method .replace ("." , "_" )
1040
1037
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
1041
1038
1042
1039
if si_sorting_analyzer_dir .exists (): # Read from spikeinterface outputs
1040
+ import spikeinterface as si
1041
+
1043
1042
sorting_analyzer = si .load_sorting_analyzer (folder = si_sorting_analyzer_dir )
1044
1043
si_sorting = sorting_analyzer .sorting
1045
1044
@@ -1054,12 +1053,10 @@ def make(self, key):
1054
1053
spike_count_dict : dict [int , int ] = si_sorting .count_num_spikes_per_unit ()
1055
1054
# {unit: spike_count}
1056
1055
1057
- # reorder channel2electrode_map according to recording channel ids
1056
+ # update channel2electrode_map to match with probe's channel index
1058
1057
channel2electrode_map = {
1059
- chn_idx : channel2electrode_map [chn_idx ]
1060
- for chn_idx in sorting_analyzer .channel_ids_to_indices (
1061
- sorting_analyzer .channel_ids
1062
- )
1058
+ idx : channel2electrode_map [int (chn_idx )]
1059
+ for idx , chn_idx in enumerate (sorting_analyzer .get_probe ().contact_ids )
1063
1060
}
1064
1061
1065
1062
# Get unit id to quality label mapping
@@ -1239,13 +1236,14 @@ def make(self, key):
1239
1236
1240
1237
# Get channel and electrode-site mapping
1241
1238
electrode_query = (EphysRecording .Channel & key ).proj (..., "-channel_name" )
1242
- channel2electrode_map = electrode_query .fetch (as_dict = True )
1243
1239
channel2electrode_map : dict [int , dict ] = {
1244
- chn .pop ("channel_idx" ): chn for chn in channel2electrode_map
1240
+ chn .pop ("channel_idx" ): chn for chn in electrode_query . fetch ( as_dict = True )
1245
1241
}
1246
1242
1247
1243
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
1248
1244
if si_sorting_analyzer_dir .exists (): # read from spikeinterface outputs
1245
+ import spikeinterface as si
1246
+
1249
1247
sorting_analyzer = si .load_sorting_analyzer (folder = si_sorting_analyzer_dir )
1250
1248
1251
1249
# Find representative channel for each unit
@@ -1256,12 +1254,10 @@ def make(self, key):
1256
1254
) # {unit: peak_channel_index}
1257
1255
unit_peak_channel = {u : chn [0 ] for u , chn in unit_peak_channel .items ()}
1258
1256
1259
- # reorder channel2electrode_map according to recording channel ids
1260
- channel_indices = sorting_analyzer .channel_ids_to_indices (
1261
- sorting_analyzer .channel_ids
1262
- ).tolist ()
1257
+ # update channel2electrode_map to match with probe's channel index
1263
1258
channel2electrode_map = {
1264
- chn_idx : channel2electrode_map [chn_idx ] for chn_idx in channel_indices
1259
+ idx : channel2electrode_map [int (chn_idx )]
1260
+ for idx , chn_idx in enumerate (sorting_analyzer .get_probe ().contact_ids )
1265
1261
}
1266
1262
1267
1263
templates = sorting_analyzer .get_extension ("templates" )
@@ -1274,12 +1270,9 @@ def yield_unit_waveforms():
1274
1270
unit_waveforms = templates .get_unit_template (
1275
1271
unit_id = unit ["unit" ], operator = "average"
1276
1272
)
1277
- peak_chn_idx = channel_indices .index (
1278
- unit_peak_channel [unit ["unit" ]]
1279
- )
1280
1273
unit_peak_waveform = {
1281
1274
** unit ,
1282
- "peak_electrode_waveform" : unit_waveforms [:, peak_chn_idx ],
1275
+ "peak_electrode_waveform" : unit_waveforms [:, unit_peak_channel [ unit [ "unit" ]] ],
1283
1276
}
1284
1277
1285
1278
unit_electrode_waveforms = [
@@ -1288,7 +1281,7 @@ def yield_unit_waveforms():
1288
1281
** channel2electrode_map [chn_idx ],
1289
1282
"waveform_mean" : unit_waveforms [:, chn_idx ],
1290
1283
}
1291
- for chn_idx in channel_indices
1284
+ for chn_idx in channel2electrode_map
1292
1285
]
1293
1286
1294
1287
yield unit_peak_waveform , unit_electrode_waveforms
@@ -1501,6 +1494,8 @@ def make(self, key):
1501
1494
1502
1495
si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
1503
1496
if si_sorting_analyzer_dir .exists (): # read from spikeinterface outputs
1497
+ import spikeinterface as si
1498
+
1504
1499
sorting_analyzer = si .load_sorting_analyzer (folder = si_sorting_analyzer_dir )
1505
1500
qc_metrics = sorting_analyzer .get_extension ("quality_metrics" ).get_data ()
1506
1501
template_metrics = sorting_analyzer .get_extension (
0 commit comments