@@ -1248,6 +1248,16 @@ def make(self, key):
1248
1248
"kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method
1249
1249
)
1250
1250
1251
+ # Get channel and electrode-site mapping
1252
+ channel_info = (
1253
+ (EphysRecording .Channel & key )
1254
+ .proj (..., "-channel_name" )
1255
+ .fetch (as_dict = True , order_by = "channel_idx" )
1256
+ )
1257
+ channel_info : dict [int , dict ] = {
1258
+ ch .pop ("channel_idx" ): ch for ch in channel_info
1259
+ } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1260
+
1251
1261
if (
1252
1262
output_dir / sorter_name / "waveform"
1253
1263
).exists (): # read from spikeinterface outputs
@@ -1256,27 +1266,12 @@ def make(self, key):
1256
1266
we : si .WaveformExtractor = si .load_waveforms (
1257
1267
waveform_dir , with_recording = False
1258
1268
)
1259
- unit_id_to_peak_channel_indices : dict [int , np .ndarray ] = (
1269
+ unit_id_to_peak_channel_map : dict [int , np .ndarray ] = (
1260
1270
si .ChannelSparsity .from_best_channels (
1261
1271
we , 1 , peak_sign = "neg"
1262
1272
).unit_id_to_channel_indices
1263
1273
) # {unit: peak_channel_index}
1264
1274
1265
- units = (CuratedClustering .Unit & key ).fetch ("KEY" , order_by = "unit" )
1266
-
1267
- # Get electrode info
1268
- electrode_config_key = (
1269
- EphysRecording * probe .ElectrodeConfig & key
1270
- ).fetch1 ("KEY" )
1271
-
1272
- electrode_query = (
1273
- probe .ProbeType .Electrode .proj () * probe .ElectrodeConfig .Electrode
1274
- & electrode_config_key
1275
- )
1276
- electrode_info = electrode_query .fetch (
1277
- "KEY" , order_by = "electrode" , as_dict = True
1278
- )
1279
-
1280
1275
# Get mean waveform for each unit from all channels
1281
1276
mean_waveforms = we .get_all_templates (
1282
1277
mode = "average"
@@ -1285,26 +1280,26 @@ def make(self, key):
1285
1280
unit_peak_waveform = []
1286
1281
unit_electrode_waveforms = []
1287
1282
1288
- for unit in units :
1283
+ for unit in ( CuratedClustering . Unit & key ). fetch ( "KEY" , order_by = "unit" ) :
1289
1284
unit_peak_waveform .append (
1290
1285
{
1291
1286
** unit ,
1292
1287
"peak_electrode_waveform" : we .get_template (
1293
1288
unit_id = unit ["unit" ], mode = "average" , force_dense = True
1294
- )[:, unit_id_to_peak_channel_indices [unit ["unit" ]][0 ]],
1289
+ )[:, unit_id_to_peak_channel_map [unit ["unit" ]][0 ]],
1295
1290
}
1296
1291
)
1297
1292
1298
1293
unit_electrode_waveforms .extend (
1299
1294
[
1300
1295
{
1301
1296
** unit ,
1302
- ** e ,
1297
+ ** channel_info [ c ] ,
1303
1298
"waveform_mean" : mean_waveforms [
1304
- unit ["unit" ], :, e [ "electrode" ]
1299
+ unit ["unit" ] - 1 , :, c
1305
1300
],
1306
1301
}
1307
- for e in electrode_info
1302
+ for c in channel_info
1308
1303
]
1309
1304
)
1310
1305
@@ -1319,16 +1314,6 @@ def make(self, key):
1319
1314
EphysRecording * ProbeInsertion & key
1320
1315
).fetch1 ("acq_software" , "probe" )
1321
1316
1322
- # Get channel and electrode-site mapping
1323
- channel_info = (
1324
- (EphysRecording .Channel & key )
1325
- .proj (..., "-channel_name" )
1326
- .fetch (as_dict = True , order_by = "channel_idx" )
1327
- )
1328
- channel_info : dict [int , dict ] = {
1329
- ch .pop ("channel_idx" ): ch for ch in channel_info
1330
- } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}}
1331
-
1332
1317
# Get all units
1333
1318
units = {
1334
1319
u ["unit" ]: u
0 commit comments