Skip to content

Update kfda.py #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 53 additions & 72 deletions kfda/kfda.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class Kfda(BaseEstimator, ClassifierMixin, TransformerMixin):

clf_ : The internal NearestCentroid classifier used in prediction.
"""

def __init__(self, n_components=2, kernel='linear', robustness_offset=1e-8,
**kwds):
self.kernel = kernel
Expand All @@ -64,113 +63,95 @@ def __init__(self, n_components=2, kernel='linear', robustness_offset=1e-8,
self.kernel = 'linear'

def fit(self, X, y):
    """Fit the Kernel FDA model according to the given training data.

    Computes the Fisher discriminant directions in kernel feature space
    by solving the generalized eigenproblem M w = lambda N w, then stores
    the per-class centroids in the projected space for later prediction.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Training vectors, where n_samples is the number of samples and
        n_features is the number of features.
    y : array-like of shape (n_samples,)
        Target values (integers).

    Returns
    -------
    self : object
        The fitted estimator.
    """
    X, y = check_X_y(X, y)
    self.classes_ = unique_labels(y)

    # n_components cannot exceed min(n_classes - 1, n_samples - 1);
    # clamp and warn instead of failing later inside eigsh.
    max_components = min(self.classes_.size - 1, X.shape[0] - 1)
    if self.n_components > max_components:
        warnings.warn(
            f"n_components > min(classes_.size - 1, n_samples - 1). "
            f"Only the first {max_components} components will be valid."
        )
        self.n_components_ = max_components
    else:
        self.n_components_ = self.n_components

    self.X_ = X
    self.y_ = y

    # sparse_output=False gives an explicit dense one-hot matrix so the
    # matrix products below stay dense ndarrays.
    y_onehot = OneHotEncoder(sparse_output=False).fit_transform(
        self.y_[:, np.newaxis])

    K = pairwise_kernels(
        X, X, metric=self.kernel, **self.kwds)

    # Per-class mean kernel rows; keepdims keeps the divisor broadcastable
    # and np.maximum guards against a zero-count class (division by zero).
    y_onehot_sums = y_onehot.T.sum(axis=1, keepdims=True)
    m_classes = y_onehot.T @ K / np.maximum(y_onehot_sums, 1e-10)

    # Map each sample to its class index to expand class means per sample.
    indices = (y_onehot @ np.arange(self.classes_.size)).astype('i')
    N = K @ (K - m_classes[indices])

    # Diagonal offset for rank robustness (N is often rank-deficient).
    N += np.eye(self.y_.size) * self.robustness_offset

    # Center class means against the overall mean kernel row.
    m_classes_centered = m_classes - K.mean(axis=0)
    M = m_classes_centered.T @ m_classes_centered

    # Small ridge on M for numerical stability of the eigensolver.
    M += np.eye(M.shape[0]) * self.robustness_offset

    # eigsh requires 1 <= k < dimension of the operator.
    k = min(self.n_components_, M.shape[0] - 1)
    if k < 1:
        k = 1  # Ensure at least one component

    # Generalized symmetric eigenproblem: M w = lambda N w, largest magnitude.
    w, self.weights_ = eigsh(M, k=k, M=N, which='LM')

    # Class centroids in the projected (discriminant) space.
    self.centroids_ = m_classes @ self.weights_

    # NOTE(review): jitter uses the global RNG, so fits on degenerate data
    # are non-deterministic — consider a seeded Generator. Kept as-is to
    # preserve the PR's behavior.
    if self.centroids_.shape[0] > 1:
        std_dev = np.std(self.centroids_, axis=0)
        if np.any(std_dev < 1e-6):
            self.centroids_ += np.random.normal(
                0, 1e-4, size=self.centroids_.shape)

    return self

def transform(self, X):
    """Project the points in X onto the Fisher discriminant directions.

    Parameters
    ----------
    X : {array-like} of shape (n_samples, n_features)
        Samples to be projected onto the fisher directions.

    Returns
    -------
    ndarray of shape (n_samples, n_components_)
        The kernel matrix against the training data multiplied by the
        learned eigenvector weights.
    """
    # Require the attributes fit() actually sets, so an informative
    # NotFittedError is raised before any computation.
    check_is_fitted(self, ['weights_', 'centroids_'])
    X = check_array(X)
    return pairwise_kernels(
        X, self.X_, metric=self.kernel, **self.kwds
    ) @ self.weights_

def predict(self, X):
    """Perform classification on an array of test vectors X.

    The predicted class C for each sample in X is returned, chosen as
    the class whose projected centroid is nearest to the projected
    sample (Euclidean distance).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)

    Returns
    -------
    C : ndarray of shape (n_samples,)
        Predicted class labels.
    """
    check_is_fitted(self, ['weights_', 'centroids_'])
    X = check_array(X)

    projected_points = self.transform(X)

    # Nearest-centroid decision computed directly from self.centroids_
    # (instead of a wrapped NearestCentroid estimator), so centroids
    # appended by fit_additional() are honored without refitting.
    distances = pairwise_distances(projected_points, self.centroids_)
    return self.classes_[np.argmin(distances, axis=1)]

def fit_additional(self, X, y):
    """Fit new classes without recomputing the discriminant weights.

    Projects the new samples with the already-learned weights, computes
    one centroid per new class, and appends the new classes/centroids to
    the fitted model's state.

    Parameters
    ----------
    X : array-like of shape (n_new_samples, n_features)
        Samples belonging to the new classes.
    y : array-like of shape (n_new_samples,)
        Target values (integers) for the new classes.

    Returns
    -------
    self : object
        The updated estimator.
    """
    check_is_fitted(self, ['weights_', 'centroids_'])
    X, y = check_X_y(X, y)

    new_classes = np.unique(y)

    projections = self.transform(X)

    # Dense one-hot encoding of the new labels, consistent with fit().
    y_onehot = OneHotEncoder(sparse_output=False).fit_transform(
        y[:, np.newaxis])

    # Per-class centroid of the projected samples; np.maximum guards
    # against a zero-count column, mirroring fit().
    y_onehot_sums = y_onehot.T.sum(axis=1, keepdims=True)
    new_centroids = y_onehot.T @ projections / np.maximum(y_onehot_sums, 1e-10)

    # NOTE(review): labels already present in classes_ are appended as
    # duplicates rather than merged — TODO confirm intended semantics.
    self.classes_ = np.concatenate([self.classes_, new_classes])
    self.centroids_ = np.concatenate([self.centroids_, new_centroids])

    return self