From 81d453439f39a39bcd605df50af3ba6e3c2e1b76 Mon Sep 17 00:00:00 2001 From: Yoav Navon Date: Sun, 7 Apr 2019 21:57:47 -0400 Subject: [PATCH] Add suport for arbitrary labels in fisher --- pybalu/feature_analysis/jfisher.py | 13 ++++++------- pybalu/feature_analysis/score.py | 7 +------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pybalu/feature_analysis/jfisher.py b/pybalu/feature_analysis/jfisher.py index 27e8ae0..1669464 100644 --- a/pybalu/feature_analysis/jfisher.py +++ b/pybalu/feature_analysis/jfisher.py @@ -5,12 +5,11 @@ def jfisher(features, classification, p=None): m = features.shape[1] - norm = classification.ravel() - classification.min() - max_class = norm.max() + 1 - if p is None: - p = np.ones(shape=(max_class, 1)) / max_class - + classes = np.unique(classification) + size = classes.shape[0] + p = np.ones(shape=(size, 1)) / size + # Centroid of all samples features_mean = features.mean(0) @@ -20,8 +19,8 @@ def jfisher(features, classification, p=None): # covariance between classes cov_b = np.zeros(shape=(m, m)) - for k in range(max_class): - ii = (norm == k) # indices from class k + for k in range(size): + ii = (classification.ravel() == classes[k]) # indices from class k class_features = features[ii,:] # samples of class k class_mean = class_features.mean(0) # centroid of class k class_cov = np.cov(class_features, rowvar=False) # covariance of class k diff --git a/pybalu/feature_analysis/score.py b/pybalu/feature_analysis/score.py index cc86868..061599b 100644 --- a/pybalu/feature_analysis/score.py +++ b/pybalu/feature_analysis/score.py @@ -8,11 +8,6 @@ def score(features, classification, *, method='fisher', param=None): - if param is None: - dn = classification.max() - classification.min() + 1 # number of classes - p = np.ones((dn, 1)) / dn - else: - p = param if method == 'mi': # mutual information raise NotImplementedError() @@ -27,7 +22,7 @@ def score(features, classification, *, method='fisher', param=None): # fisher elif method == 'fisher': - return jfisher(features, classification, p) + return jfisher(features, classification) elif method == 'sp100': return sp100(features, classification)