diff --git a/pybalu/feature_analysis/jfisher.py b/pybalu/feature_analysis/jfisher.py
index 27e8ae0..1669464 100644
--- a/pybalu/feature_analysis/jfisher.py
+++ b/pybalu/feature_analysis/jfisher.py
@@ -5,12 +5,11 @@ def jfisher(features, classification, p=None):
 
     m = features.shape[1]
 
-    norm = classification.ravel() - classification.min()
-    max_class = norm.max() + 1
-    if p is None:
-        p = np.ones(shape=(max_class, 1)) / max_class
-
+    classes = np.unique(classification)
+    size = classes.shape[0]
+    p = np.ones(shape=(size, 1)) / size
+
     # Centroid of all samples
     features_mean = features.mean(0)
 
@@ -20,8 +19,8 @@ def jfisher(features, classification, p=None):
     # covariance between classes
     cov_b = np.zeros(shape=(m, m))
 
-    for k in range(max_class):
-        ii = (norm == k)                                   # indices from class k
+    for k in range(size):
+        ii = (classification.ravel() == classes[k])        # indices from class k
         class_features = features[ii,:]                    # samples of class k
         class_mean = class_features.mean(0)                # centroid of class k
         class_cov = np.cov(class_features, rowvar=False)   # covariance of class k
diff --git a/pybalu/feature_analysis/score.py b/pybalu/feature_analysis/score.py
index cc86868..061599b 100644
--- a/pybalu/feature_analysis/score.py
+++ b/pybalu/feature_analysis/score.py
@@ -8,11 +8,6 @@ def score(features, classification, *, method='fisher', param=None):
 
-    if param is None:
-        dn = classification.max() - classification.min() + 1    # number of classes
-        p = np.ones((dn, 1)) / dn
-    else:
-        p = param
 
     if method == 'mi':
         # mutual information
         raise NotImplementedError()
@@ -27,7 +22,7 @@ def score(features, classification, *, method='fisher', param=None):
 
     # fisher
    elif method == 'fisher':
-        return jfisher(features, classification, p)
+        return jfisher(features, classification)
 
     elif method == 'sp100':
         return sp100(features, classification)
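
Note (not part of the patch): the removed code in both files assumed class labels form a contiguous integer range starting at the minimum label, and score built the prior p from max - min + 1. After the patch, jfisher derives the classes with np.unique and always uses a uniform prior, and score no longer forwards param to it. A small, hypothetical sketch of the difference in class enumeration; the labels below are invented for illustration and are not taken from the repository:

    import numpy as np

    classification = np.array([1, 1, 4, 4, 7])    # hypothetical gapped labels

    # Removed approach: shift by the minimum and iterate 0..max, so gapped
    # label sets produce "classes" with no samples.
    norm = classification.ravel() - classification.min()
    old_classes = range(norm.max() + 1)           # range(7): 7 candidate classes
    empty = [k for k in old_classes if not np.any(norm == k)]
    print(len(old_classes), empty)                # 7 [1, 2, 4, 5]

    # Patched approach: iterate only over labels that actually occur.
    classes = np.unique(classification)           # array([1, 4, 7])
    size = classes.shape[0]
    p = np.ones(shape=(size, 1)) / size           # uniform prior over 3 real classes
    print(size, p.ravel())                        # 3 [0.33333333 0.33333333 0.33333333]

With the old enumeration, an empty class would contribute an undefined mean and covariance to the per-class loop; np.unique visits only the labels that are present, so no such class can arise.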