|
4 | 4 | import numpy as np
|
5 | 5 | import inspect
|
6 | 6 | import importlib
|
| 7 | +import pandas as pd |
| 8 | + |
7 | 9 | from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid
|
8 | 10 |
|
9 | 11 | from .readers import spikeglx, kilosort, openephys
|
@@ -796,6 +798,74 @@ def yield_unit_waveforms():
|
796 | 798 | self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
|
797 | 799 |
|
798 | 800 |
|
| 801 | +@schema |
| 802 | +class QualityMetrics(dj.Imported): |
| 803 | + definition = """ |
| 804 | + # Clusters and waveforms metrics |
| 805 | + -> CuratedClustering |
| 806 | + """ |
| 807 | + |
| 808 | + class Cluster(dj.Part): |
| 809 | + definition = """ |
| 810 | + # Cluster metrics for a particular unit |
| 811 | + -> master |
| 812 | + -> CuratedClustering.Unit |
| 813 | + --- |
| 814 | + firing_rate=null: float # (Hz) firing rate for a unit |
| 815 | + snr=null: float # signal-to-noise ratio for a unit |
| 816 | + presence_ratio=null: float # fraction of time in which spikes are present |
| 817 | + isi_violation=null: float # rate of ISI violation as a fraction of overall rate |
| 818 | + number_violation=null: int # total number of ISI violations |
| 819 | + amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram |
| 820 | + isolation_distance=null: float # distance to nearest cluster in Mahalanobis space |
| 821 | + l_ratio=null: float # |
| 822 | + d_prime=null: float # Classification accuracy based on LDA |
| 823 | + nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster |
| 824 | + nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster |
| 825 | + silhouette_score=null: float # Standard metric for cluster overlap |
| 826 | + max_drift=null: float # Maximum change in spike depth throughout recording |
| 827 | + cumulative_drift=null: float # Cumulative change in spike depth throughout recording |
| 828 | + contamination_rate=null: float # |
| 829 | + """ |
| 830 | + |
| 831 | + class Waveform(dj.Part): |
| 832 | + definition = """ |
| 833 | + # Waveform metrics for a particular unit |
| 834 | + -> master |
| 835 | + -> CuratedClustering.Unit |
| 836 | + --- |
| 837 | + amplitude: float # (uV) absolute difference between waveform peak and trough |
| 838 | + duration: float # (ms) time between waveform peak and trough |
| 839 | + halfwidth=null: float # (ms) spike width at half max amplitude |
| 840 | + pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 |
| 841 | + repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak |
| 842 | + recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail |
| 843 | + spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe |
| 844 | + velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe |
| 845 | + velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe |
| 846 | + """ |
| 847 | + |
| 848 | + def make(self, key): |
| 849 | + output_dir = (ClusteringTask & key).fetch1('clustering_output_dir') |
| 850 | + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) |
| 851 | + |
| 852 | + metric_fp = kilosort_dir / 'metrics.csv' |
| 853 | + |
| 854 | + if not metric_fp.exists(): |
| 855 | + raise FileNotFoundError(f'QC metrics file not found: {metric_fp}') |
| 856 | + |
| 857 | + metrics_df = pd.read_csv(metric_fp) |
| 858 | + metrics_df.set_index('cluster_id', inplace=True) |
| 859 | + |
| 860 | + metrics_list = [ |
| 861 | + dict(metrics_df.loc[unit_key['unit']], **unit_key) |
| 862 | + for unit_key in (CuratedClustering.Unit & key).fetch('KEY')] |
| 863 | + |
| 864 | + self.insert1(key) |
| 865 | + self.Cluster.insert(metrics_list, ignore_extra_fields=True) |
| 866 | + self.Waveform.insert(metrics_list, ignore_extra_fields=True) |
| 867 | + |
| 868 | + |
799 | 869 | # ---------------- HELPER FUNCTIONS ----------------
|
800 | 870 |
|
801 | 871 | def get_spikeglx_meta_filepath(ephys_recording_key):
|
|
0 commit comments