|
4 | 4 | import numpy as np
|
5 | 5 | import inspect
|
6 | 6 | import importlib
|
| 7 | +import pandas as pd |
7 | 8 |
|
8 | 9 | from .readers import spikeglx, kilosort, openephys
|
9 | 10 | from . import probe, find_full_path, find_root_directory, dict_to_uuid
|
@@ -646,6 +647,78 @@ def yield_unit_waveforms():
|
646 | 647 | self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
|
647 | 648 |
|
648 | 649 |
|
| 650 | +@schema |
| 651 | +class QualityMetric(dj.Imported): |
| 652 | + definition = """ |
| 653 | + # Clusters and waveforms metrics |
| 654 | + -> CuratedClustering |
| 655 | + """ |
| 656 | + |
| 657 | + class Cluster(dj.Part): |
| 658 | + definition = """ |
| 659 | + # Cluster metrics for a particular unit |
| 660 | + -> master |
| 661 | + -> CuratedClustering.Unit |
| 662 | + --- |
| 663 | + firing_rate=null: float # (Hz) firing rate for a unit |
| 664 | + snr=null: float # signal-to-noise ratio for a unit |
| 665 | + presence_ratio=null: float # fraction of time in which spikes are present |
| 666 | + isi_violation=null: float # rate of ISI violation as a fraction of overall rate |
| 667 | + number_violation=null: int # total number of ISI violations |
| 668 | + amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram |
| 669 | + isolation_distance=null: float # distance to nearest cluster in Mahalanobis space |
| 670 | + l_ratio=null: float # |
| 671 | + d_prime=null: float # Classification accuracy based on LDA |
| 672 | + nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster |
| 673 | + nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster |
| 674 | + silhouette_score=null: float # Standard metric for cluster overlap |
| 675 | + max_drift=null: float # Maximum change in spike depth throughout recording |
| 676 | + cumulative_drift=null: float # Cumulative change in spike depth throughout recording |
| 677 | + contamination_rate=null: float # |
| 678 | + """ |
| 679 | + |
| 680 | + class Waveform(dj.Part): |
| 681 | + definition = """ |
| 682 | + # Waveform metrics for a particular unit |
| 683 | + -> master |
| 684 | + -> CuratedClustering.Unit |
| 685 | + --- |
| 686 | + amplitude: float # (uV) absolute difference between waveform peak and trough |
| 687 | + duration: float # (ms) time between waveform peak and trough |
| 688 | + halfwidth=null: float # (ms) spike width at half max amplitude |
| 689 | + pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 |
| 690 | + repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak |
| 691 | + recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail |
| 692 | + spread=null: float # (um) the range with amplitude above 12% of the maximum amplitude along the probe |
| 693 | + velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe |
| 694 | + velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe |
| 695 | + """ |
| 696 | + |
| 697 | + def make(self, key): |
| 698 | + output_dir = (ClusteringTask & key).fetch1('clustering_output_dir') |
| 699 | + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) |
| 700 | + |
| 701 | + metric_fp = kilosort_dir / 'metrics.csv' |
| 702 | + |
| 703 | + if not metric_fp.exists(): |
| 704 | + raise FileNotFoundError(f'QC metrics file not found: {metric_fp}') |
| 705 | + |
| 706 | + metrics_df = pd.read_csv(metric_fp) |
| 707 | + metrics_df.set_index('cluster_id', inplace=True) |
| 708 | + |
| 709 | + # Get all units |
| 710 | + units = {u['unit']: u for u in (CuratedClustering.Unit & key).fetch( |
| 711 | + 'KEY', order_by='unit')} |
| 712 | + |
| 713 | + metrics_list = [] |
| 714 | + for unit, unit_key in units.items(): |
| 715 | + metrics_list.append({**unit_key, **dict(metrics_df.loc[unit])}) |
| 716 | + |
| 717 | + self.insert1(key) |
| 718 | + self.Cluster.insert(metrics_list, ignore_extra_fields=True) |
| 719 | + self.Waveform.insert(metrics_list, ignore_extra_fields=True) |
| 720 | + |
| 721 | + |
649 | 722 | # ---------------- HELPER FUNCTIONS ----------------
|
650 | 723 |
|
651 | 724 | def get_spikeglx_meta_filepath(ephys_recording_key):
|
|
0 commit comments