|
4 | 4 | import numpy as np
|
5 | 5 | import inspect
|
6 | 6 | import importlib
|
| 7 | +import pandas as pd |
| 8 | + |
7 | 9 | from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid
|
8 | 10 |
|
9 | 11 | from .readers import spikeglx, kilosort, openephys
|
@@ -662,6 +664,74 @@ def yield_unit_waveforms():
|
662 | 664 | self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
|
663 | 665 |
|
664 | 666 |
|
| 667 | +@schema |
| 668 | +class QualityMetrics(dj.Imported): |
| 669 | + definition = """ |
| 670 | + # Clusters and waveforms metrics |
| 671 | + -> CuratedClustering |
| 672 | + """ |
| 673 | + |
| 674 | + class Cluster(dj.Part): |
| 675 | + definition = """ |
| 676 | + # Cluster metrics for a particular unit |
| 677 | + -> master |
| 678 | + -> CuratedClustering.Unit |
| 679 | + --- |
| 680 | + firing_rate=null: float # (Hz) firing rate for a unit |
| 681 | + snr=null: float # signal-to-noise ratio for a unit |
| 682 | + presence_ratio=null: float # fraction of time in which spikes are present |
| 683 | + isi_violation=null: float # rate of ISI violation as a fraction of overall rate |
| 684 | + number_violation=null: int # total number of ISI violations |
| 685 | + amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram |
| 686 | + isolation_distance=null: float # distance to nearest cluster in Mahalanobis space |
| 687 | + l_ratio=null: float # |
| 688 | + d_prime=null: float # Classification accuracy based on LDA |
| 689 | + nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster |
| 690 | + nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster |
| 691 | + silhouette_score=null: float # Standard metric for cluster overlap |
| 692 | + max_drift=null: float # Maximum change in spike depth throughout recording |
| 693 | + cumulative_drift=null: float # Cumulative change in spike depth throughout recording |
| 694 | + contamination_rate=null: float # |
| 695 | + """ |
| 696 | + |
| 697 | + class Waveform(dj.Part): |
| 698 | + definition = """ |
| 699 | + # Waveform metrics for a particular unit |
| 700 | + -> master |
| 701 | + -> CuratedClustering.Unit |
| 702 | + --- |
| 703 | + amplitude: float # (uV) absolute difference between waveform peak and trough |
| 704 | + duration: float # (ms) time between waveform peak and trough |
| 705 | + halfwidth=null: float # (ms) spike width at half max amplitude |
| 706 | + pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 |
| 707 | + repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak |
| 708 | + recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail |
| 709 | + spread=null: float # (um) the range with amplitude above 12% of the maximum amplitude along the probe |
| 710 | + velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe |
| 711 | + velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe |
| 712 | + """ |
| 713 | + |
| 714 | + def make(self, key): |
| 715 | + output_dir = (ClusteringTask & key).fetch1('clustering_output_dir') |
| 716 | + kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) |
| 717 | + |
| 718 | + metric_fp = kilosort_dir / 'metrics.csv' |
| 719 | + |
| 720 | + if not metric_fp.exists(): |
| 721 | + raise FileNotFoundError(f'QC metrics file not found: {metric_fp}') |
| 722 | + |
| 723 | + metrics_df = pd.read_csv(metric_fp) |
| 724 | + metrics_df.set_index('cluster_id', inplace=True) |
| 725 | + |
| 726 | + metrics_list = [ |
| 727 | + dict(metrics_df.loc[unit_key['unit']], **unit_key) |
| 728 | + for unit_key in (CuratedClustering.Unit & key).fetch('KEY')] |
| 729 | + |
| 730 | + self.insert1(key) |
| 731 | + self.Cluster.insert(metrics_list, ignore_extra_fields=True) |
| 732 | + self.Waveform.insert(metrics_list, ignore_extra_fields=True) |
| 733 | + |
| 734 | + |
665 | 735 | # ---------------- HELPER FUNCTIONS ----------------
|
666 | 736 |
|
667 | 737 | def get_spikeglx_meta_filepath(ephys_recording_key):
|
|
0 commit comments