From 727a4818dce34dc80609fb6368d0cdb1e1b04eca Mon Sep 17 00:00:00 2001 From: bendeez Date: Sun, 15 Jun 2025 10:03:40 -0500 Subject: [PATCH 1/8] implemented multiclass for svc --- src/svm/svc.rs | 342 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 295 insertions(+), 47 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index cc5a0beb..e8b93cbf 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -58,10 +58,11 @@ //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; //! //! let knl = Kernels::linear(); -//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl); -//! let svc = SVC::fit(&x, &y, params).unwrap(); +//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl); +//! let svc = SVC::fit(&x, &y, parameters, None).unwrap(); //! //! let y_hat = svc.predict(&x).unwrap(); +//! //! ``` //! //! ## References: @@ -84,12 +85,196 @@ use serde::{Deserialize, Serialize}; use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow}; use crate::error::{Failed, FailedError}; -use crate::linalg::basic::arrays::{Array1, Array2, MutArray}; +use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray}; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; use crate::rand_custom::get_rng_impl; use crate::svm::Kernel; +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] +/// Configuration for a multi-class Support Vector Machine (SVM) classifier. +/// This struct holds the indices of the data points relevant to a specific binary +/// classification problem within a multi-class context, and the two classes +/// being discriminated. +pub struct MultiClassConfig { + /// The indices of the data points from the original dataset that belong to the two `classes`. + indices: Vec, + /// A tuple representing the two classes that this configuration is designed to distinguish. + classes: (TY, TY), +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + SupervisedEstimatorBorrow<'a, X, Y, SVCParameters> + for MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Creates a new, empty `MultiClassSVC` instance. + fn new() -> Self { + Self { + classifiers: Option::None, + } + } + + /// Fits the `MultiClassSVC` model to the provided data and parameters. + /// + /// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array). + /// * `y` - A reference to the target labels (1D array). + /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training. + /// + /// # Returns + /// A `Result` indicating success (`Self`) or failure (`Failed`). + fn fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + ) -> Result { + MultiClassSVC::fit(x, y, parameters) + } +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Predicts the class labels for new data points. + /// + /// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method. + /// It unwraps the inner `Result` from `MultiClassSVC::predict`, assuming that + /// the prediction will always succeed after a successful fit. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) for which to make predictions. + /// + /// # Returns + /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error. 
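+    ///
+    /// Note: callers that want to handle a prediction failure themselves should
+    /// prefer the inherent `MultiClassSVC::predict`, since this trait wrapper
+    /// unwraps the inner result.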
+ fn predict(&self, x: &'a X) -> Result, Failed> { + Ok(self.predict(x).unwrap()) + } +} + +/// A multi-class Support Vector Machine (SVM) classifier. +/// +/// This struct implements a multi-class SVM using the "one-vs-one" strategy, +/// where a separate binary SVC classifier is trained for every pair of classes. +/// +/// # Type Parameters +/// * `'a` - Lifetime parameter for borrowed data. +/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`). +/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`). +/// * `X` - The type representing the 2D array of input features (e.g., a matrix). +/// * `Y` - The type representing the 1D array of target labels (e.g., a vector). +pub struct MultiClassSVC< + 'a, + TX: Number + RealNumber, + TY: Number + Ord, + X: Array2, + Y: Array1, +> { + /// An optional vector of binary `SVC` classifiers. + classifiers: Option>>, +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy. + /// + /// This method identifies all unique classes in the target labels `y` and then + /// trains a binary `SVC` for every unique pair of classes. For each pair, it + /// extracts the relevant data points and their labels, and then trains a + /// specialized `SVC` for that binary classification task. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array). + /// * `y` - A reference to the target labels (1D array). + /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier. + /// + /// + /// # Returns + /// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`). + pub fn fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + ) -> Result, Failed> { + let unique_classes = y.unique(); + let mut classifiers = Vec::new(); + // Iterate through all unique pairs of classes (one-vs-one strategy) + for i in 0..unique_classes.len() { + for j in i..unique_classes.len() { + if i == j { + continue; + } + let class0 = unique_classes[j]; + let class1 = unique_classes[i]; + + let mut indices = Vec::new(); + // Collect indices of data points belonging to the current pair of classes + for (index, v) in y.iterator(0).enumerate() { + if *v == class0 || *v == class1 { + indices.push(index) + } + } + let classes = (class0, class1); + let multiclass_config = MultiClassConfig { classes, indices }; + // Fit a binary SVC for the current pair of classes + let svc = SVC::fit(x, y, parameters, Some(multiclass_config)).unwrap(); + classifiers.push(svc); + } + } + Ok(Self { + classifiers: Some(classifiers), + }) + } + + /// Predicts the class labels for new data points using the trained multi-class SVM. + /// + /// This method uses a "voting" scheme (majority vote) among all the binary + /// classifiers to determine the final prediction for each data point. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) for which to make predictions. + /// + /// # Returns + /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error. 
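+    ///
+    /// # Example
+    /// A minimal sketch, mirroring the `svc_multiclass_fit_predict` test added in this
+    /// patch (the tiny dataset here is hypothetical; any numeric matrix works the same way):
+    /// ```ignore
+    /// use smartcore::linalg::basic::matrix::DenseMatrix;
+    /// use smartcore::svm::Kernels;
+    /// use smartcore::svm::svc::{MultiClassSVC, SVCParameters};
+    ///
+    /// let x = DenseMatrix::from_2d_array(&[
+    ///     &[5.1, 3.5], &[4.9, 3.0], &[7.0, 3.2],
+    ///     &[6.4, 3.2], &[6.3, 2.5], &[6.5, 3.0],
+    /// ]).unwrap();
+    /// let y: Vec<i32> = vec![0, 0, 1, 1, 2, 2];
+    ///
+    /// let parameters = SVCParameters::default().with_c(200.0).with_kernel(Kernels::linear());
+    /// let labels = MultiClassSVC::fit(&x, &y, &parameters)
+    ///     .and_then(|svc| svc.predict(&x))
+    ///     .unwrap();
+    /// ```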
+ /// + pub fn predict(&self, x: &X) -> Result, Failed> { + // Initialize a HashMap for each data point to store votes for each class + let mut polls = vec![HashMap::new(); x.shape().0]; + // Retrieve the trained binary classifiers + let classifiers = self.classifiers.as_ref().unwrap(); + + // Iterate through each binary classifier + for i in 0..classifiers.len() { + let svc = classifiers.get(i).unwrap(); + let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier + + // For each prediction from the current binary classifier + for (j, prediction) in predictions.iter().enumerate() { + let prediction = prediction.to_i32().unwrap(); + let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point + // Increment the vote for the predicted class + if let Some(count) = poll.get_mut(&prediction) { + *count += 1 + } else { + poll.insert(prediction, 1); + } + } + } + + // Determine the final prediction for each data point based on majority vote + Ok(polls + .iter() + .map(|v| { + // Find the class with the maximum votes for each data point + TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap() + }) + .collect()) + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] /// SVC Parameters @@ -123,7 +308,7 @@ pub struct SVCParameters, Y: Array1> { - classes: Option>, + classes: Option<(TY, TY)>, instances: Option>>, #[cfg_attr(feature = "serde", serde(skip))] parameters: Option<&'a SVCParameters>, @@ -152,7 +337,9 @@ struct Cache, Y: Array1 struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> { x: &'a X, y: &'a Y, + indices: Option>, parameters: &'a SVCParameters, + classes: &'a (TY, TY), svmin: usize, svmax: usize, gmin: TX, @@ -180,12 +367,12 @@ impl, Y: Array1> self.tol = tol; self } + /// The kernel function. pub fn with_kernel(mut self, kernel: K) -> Self { self.kernel = Some(Box::new(kernel)); self } - /// Seed for the pseudo random number generator. 
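    /// For example (a small sketch based on the tests in this file), fixing the
    /// seed makes the optimizer's shuffling step deterministic:
    /// ```ignore
    /// let parameters = SVCParameters::default()
    ///     .with_c(200.0)
    ///     .with_kernel(Kernels::linear())
    ///     .with_seed(Some(100));
    /// ```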
pub fn with_seed(mut self, seed: Option) -> Self { self.seed = seed; @@ -226,7 +413,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 y: &'a Y, parameters: &'a SVCParameters, ) -> Result { - SVC::fit(x, y, parameters) + SVC::fit(x, y, parameters, None) } } @@ -249,6 +436,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array x: &'a X, y: &'a Y, parameters: &'a SVCParameters, + multiclass_config: Option>, ) -> Result, Failed> { let (n, _) = x.shape(); @@ -265,27 +453,22 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array )); } - let classes = y.unique(); - - if classes.len() != 2 { - return Err(Failed::fit(&format!( - "Incorrect number of classes: {}", - classes.len() - ))); - } - - // Make sure class labels are either 1 or -1 - for e in y.iterator(0) { - let y_v = e.to_i32().unwrap(); - if y_v != -1 && y_v != 1 { - return Err(Failed::because( - FailedError::ParametersError, - "Class labels must be 1 or -1", - )); + let (indices, classes) = if let Some(multiclass_config) = multiclass_config { + let classes = multiclass_config.classes; + (Some(multiclass_config.indices), classes) + } else { + let classes = y.unique(); + if classes.len() != 2 { + return Err(Failed::fit(&format!( + "Incorrect number of classes: {}", + classes.len() + ))); } - } + (None, (classes[0], classes[1])) + }; - let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters); + let optimizer: Optimizer<'_, TX, TY, X, Y> = + Optimizer::new(x, y, indices, parameters, &classes); let (support_vectors, weight, b) = optimizer.optimize(); @@ -305,9 +488,9 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array let mut y_hat: Vec = self.decision_function(x)?; for i in 0..y_hat.len() { - let cls_idx = match *y_hat.get(i).unwrap() > TX::zero() { - false => TX::from(self.classes.as_ref().unwrap()[0]).unwrap(), - true => TX::from(self.classes.as_ref().unwrap()[1]).unwrap(), + let cls_idx = match *y_hat.get(i) > TX::zero() { + false => TX::from(self.classes.as_ref().unwrap().0).unwrap(), + true => TX::from(self.classes.as_ref().unwrap().1).unwrap(), }; y_hat.set(i, cls_idx); @@ -445,14 +628,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 fn new( x: &'a X, y: &'a Y, + indices: Option>, parameters: &'a SVCParameters, + classes: &'a (TY, TY), ) -> Optimizer<'a, TX, TY, X, Y> { let (n, _) = x.shape(); Optimizer { x, y, + indices, parameters, + classes, svmin: 0, svmax: 0, gmin: ::max_value(), @@ -478,7 +665,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 for i in self.permutate(n) { x.clear(); x.extend(self.x.get_row(i).iterator(0).take(n).copied()); - self.process(i, &x, *self.y.get(i), &mut cache); + let y = if *self.y.get(i) == self.classes.1 { + 1 + } else { + -1 + } as f64; + self.process(i, &x, y, &mut cache); loop { self.reprocess(tol, &mut cache); self.find_min_max_gradient(); @@ -514,14 +706,16 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 for i in self.permutate(n) { x.clear(); x.extend(self.x.get_row(i).iterator(0).take(n).copied()); - if *self.y.get(i) == TY::one() && cp < few { - if self.process(i, &x, *self.y.get(i), cache) { + let y = if *self.y.get(i) == self.classes.1 { + 1 + } else { + -1 + } as f64; + if y == 1.0 && cp < few { + if self.process(i, &x, y, cache) { cp += 1; } - } else if *self.y.get(i) == TY::from(-1).unwrap() - && cn < few - && self.process(i, &x, *self.y.get(i), cache) - { + } else if y == 
-1.0 && cn < few && self.process(i, &x, y, cache) { cn += 1; } @@ -531,14 +725,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 } } - fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache) -> bool { + fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache) -> bool { for j in 0..self.sv.len() { if self.sv[j].index == i { return true; } } - let mut g: f64 = y.to_f64().unwrap(); + let mut g = y; let mut cache_values: Vec<((usize, usize), TX)> = Vec::new(); @@ -559,8 +753,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 self.find_min_max_gradient(); if self.gmin < self.gmax - && ((y > TY::zero() && g < self.gmin.to_f64().unwrap()) - || (y < TY::zero() && g > self.gmax.to_f64().unwrap())) + && ((y > 0.0 && g < self.gmin.to_f64().unwrap()) + || (y < 0.0 && g > self.gmax.to_f64().unwrap())) { return false; } @@ -590,7 +784,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 ), ); - if y > TY::zero() { + if y > 0.0 { self.smo(None, Some(0), TX::zero(), cache); } else { self.smo(Some(0), None, TX::zero(), cache); @@ -647,7 +841,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 let gmin = self.gmin; let mut idxs_to_drop: HashSet = HashSet::new(); - self.sv.retain(|v| { if v.alpha == 0f64 && ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap()) @@ -666,7 +859,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 fn permutate(&self, n: usize) -> Vec { let mut rng = get_rng_impl(self.parameters.seed); - let mut range: Vec = (0..n).collect(); + let mut range = if let Some(indices) = self.indices.clone() { + indices + } else { + (0..n).collect::>() + }; range.shuffle(&mut rng); range } @@ -930,6 +1127,55 @@ mod tests { use crate::metrics::accuracy; use crate::svm::Kernels; + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn svc_multiclass_fit_predict() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]) + .unwrap(); + + let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]; + + let knl = Kernels::linear(); + let parameters = SVCParameters::default() + .with_c(200.0) + .with_kernel(knl) + .with_seed(Some(100)); + + let y_hat = MultiClassSVC::fit(&x, &y, ¶meters) + .and_then(|lr| lr.predict(&x)) + .unwrap(); + + let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); + + assert!( + acc >= 0.9, + "Multiclass accuracy ({acc}) is not larger or equal to 0.9" + ); + } #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test @@ -965,12 +1211,12 @@ mod tests { ]; let knl = Kernels::linear(); - let params = SVCParameters::default() + let parameters = SVCParameters::default() .with_c(200.0) .with_kernel(knl) .with_seed(Some(100)); - let y_hat = SVC::fit(&x, &y, ¶ms) + let y_hat = SVC::fit(&x, &y, ¶meters, None) .and_then(|lr| lr.predict(&x)) .unwrap(); let acc = accuracy(&y, 
&(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); @@ -1005,6 +1251,7 @@ mod tests { &SVCParameters::default() .with_c(200.0) .with_kernel(Kernels::linear()), + None, ) .and_then(|lr| lr.decision_function(&x2)) .unwrap(); @@ -1061,6 +1308,7 @@ mod tests { &SVCParameters::default() .with_c(1.0) .with_kernel(Kernels::rbf().with_gamma(0.7)), + None, ) .and_then(|lr| lr.predict(&x)) .unwrap(); @@ -1106,8 +1354,8 @@ mod tests { ]; let knl = Kernels::linear(); - let params = SVCParameters::default().with_kernel(knl); - let svc = SVC::fit(&x, &y, ¶ms).unwrap(); + let parameters = SVCParameters::default().with_kernel(knl); + let svc = SVC::fit(&x, &y, ¶meters).unwrap(); // serialization let deserialized_svc: SVC<'_, f64, i32, _, _> = @@ -1115,4 +1363,4 @@ mod tests { assert_eq!(svc, deserialized_svc); } -} +} \ No newline at end of file From 93225268cfc1085140ade2c4b9f3221004dda84f Mon Sep 17 00:00:00 2001 From: bendeez Date: Sun, 15 Jun 2025 12:43:41 -0500 Subject: [PATCH 2/8] modified the multiclass svc so it doesnt modify the current api --- src/svm/svc.rs | 154 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 116 insertions(+), 38 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index e8b93cbf..f5aa7611 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -59,7 +59,7 @@ //! //! let knl = Kernels::linear(); //! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl); -//! let svc = SVC::fit(&x, &y, parameters, None).unwrap(); +//! let svc = SVC::fit(&x, &y, parameters).unwrap(); //! //! let y_hat = svc.predict(&x).unwrap(); //! @@ -97,7 +97,7 @@ use crate::svm::Kernel; /// This struct holds the indices of the data points relevant to a specific binary /// classification problem within a multi-class context, and the two classes /// being discriminated. -pub struct MultiClassConfig { +struct MultiClassConfig { /// The indices of the data points from the original dataset that belong to the two `classes`. indices: Vec, /// A tuple representing the two classes that this configuration is designed to distinguish. @@ -220,7 +220,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 let classes = (class0, class1); let multiclass_config = MultiClassConfig { classes, indices }; // Fit a binary SVC for the current pair of classes - let svc = SVC::fit(x, y, parameters, Some(multiclass_config)).unwrap(); + let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap(); classifiers.push(svc); } } @@ -413,7 +413,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 y: &'a Y, parameters: &'a SVCParameters, ) -> Result { - SVC::fit(x, y, parameters, None) + SVC::fit(x, y, parameters) } } @@ -428,18 +428,109 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array1 + 'a> SVC<'a, TX, TY, X, Y> { - /// Fits SVC to your data. - /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. - /// * `y` - class labels - /// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values. + /// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios. + /// + /// This function is intended to be called by a multi-class strategy (e.g., one-vs-one) + /// to train individual binary SVCs. 
It takes a `MultiClassConfig` which specifies + /// the two classes this SVC should discriminate and the subset of data indices + /// relevant to these classes. It then delegates the actual optimization and fitting + /// to `optimize_and_fit`. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// * `parameters` - A reference to the `SVCParameters` controlling the training process + /// (e.g., kernel, C-value, tolerance). + /// * `multiclass_config` - A `MultiClassConfig` struct containing: + /// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC + /// should distinguish. + /// - `indices`: A `Vec` containing the indices of the data points in `x` and `y` + /// that belong to either `class0` or `class1`. + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. + /// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters). + fn multiclass_fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + multiclass_config: MultiClassConfig, + ) -> Result, Failed> { + let classes = multiclass_config.classes; + let indices = multiclass_config.indices; + let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices)); + svc + } + + /// Fits a binary Support Vector Classifier (SVC) to the provided data. + /// + /// This is the primary `fit` method for a standalone binary SVC. It expects + /// the target labels `y` to contain exactly two unique classes. If more or + /// fewer than two classes are found, it returns an error. It then extracts + /// these two classes and proceeds to optimize and fit the SVC model. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// `y` must contain exactly two unique class labels. + /// * `parameters` - A reference to the `SVCParameters` controlling the training process. + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. + /// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, + /// or if the underlying optimization fails. pub fn fit( x: &'a X, y: &'a Y, parameters: &'a SVCParameters, - multiclass_config: Option>, ) -> Result, Failed> { - let (n, _) = x.shape(); + let classes = y.unique(); + // Validate that there are exactly two unique classes in the target labels. + if classes.len() != 2 { + return Err(Failed::fit(&format!( + "Incorrect number of classes: {}. A binary SVC requires exactly two classes.", + classes.len() + ))); + } + let classes = (classes[0], classes[1]); + let svc = Self::optimize_and_fit(x, y, parameters, classes, None); + svc + } + /// Internal function to optimize and fit the Support Vector Classifier. + /// + /// This is the core logic for training a binary SVC. It performs several checks + /// (e.g., kernel presence, data shape consistency) and then initializes an + /// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`). + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration. 
+ /// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels + /// that the SVC will learn to separate. + /// * `indices` - An `Option>`. If `Some`, it contains the specific indices + /// of data points from `x` and `y` that should be used for training this + /// binary classifier. If `None`, all data points in `x` and `y` are considered. + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model + /// components (support vectors, weights, bias). + /// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, + /// mismatched data shapes), or if the optimization process fails. + fn optimize_and_fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + classes: (TY, TY), + indices: Option>, + ) -> Result, Failed> { + let (n_samples, _) = x.shape(); + + // Validate that a kernel has been defined in the parameters. if parameters.kernel.is_none() { return Err(Failed::because( FailedError::ParametersError, @@ -447,41 +538,30 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array )); } - if n != y.shape() { + // Validate that the number of samples in X matches the number of labels in Y. + if n_samples != y.shape() { return Err(Failed::fit( - "Number of rows of X doesn\'t match number of rows of Y", + "Number of rows of X doesn't match number of rows of Y", )); } - let (indices, classes) = if let Some(multiclass_config) = multiclass_config { - let classes = multiclass_config.classes; - (Some(multiclass_config.indices), classes) - } else { - let classes = y.unique(); - if classes.len() != 2 { - return Err(Failed::fit(&format!( - "Incorrect number of classes: {}", - classes.len() - ))); - } - (None, (classes[0], classes[1])) - }; - let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, indices, parameters, &classes); - let (support_vectors, weight, b) = optimizer.optimize(); + // Perform the optimization to find the support vectors, weight vector, and bias. + // This is where the core SVM algorithm (e.g., SMO) would run. + let (support_vectors, weight, b) = optimizer.optimize(); + // Construct and return the fitted SVC model. Ok(SVC::<'a> { - classes: Some(classes), - instances: Some(support_vectors), - parameters: Some(parameters), - w: Some(weight), - b: Some(b), - phantomdata: PhantomData, + classes: Some(classes), // Store the two classes the SVC was trained on. + instances: Some(support_vectors), // Store the data points that are support vectors. + parameters: Some(parameters), // Reference to the parameters used for fitting. + w: Some(weight), // The learned weight vector (for linear kernels). + b: Some(b), // The learned bias term. + phantomdata: PhantomData, // Placeholder for type parameters not directly stored. }) } - /// Predicts estimated class labels from `x` /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. 
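    /// For example (a minimal sketch; the module-level docs and the tests below
    /// show the full construction of `x`, `y` and `parameters`):
    /// ```ignore
    /// let y_hat = SVC::fit(&x, &y, &parameters)
    ///     .and_then(|svc| svc.predict(&x))
    ///     .unwrap();
    /// ```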
    pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
@@ -1216,7 +1296,7 @@ mod tests {
             .with_kernel(knl)
             .with_seed(Some(100));

-        let y_hat = SVC::fit(&x, &y, &parameters, None)
+        let y_hat = SVC::fit(&x, &y, &parameters)
             .and_then(|lr| lr.predict(&x))
             .unwrap();
         let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
@@ -1251,7 +1331,6 @@ mod tests {
             &SVCParameters::default()
                 .with_c(200.0)
                 .with_kernel(Kernels::linear()),
-            None,
         )
         .and_then(|lr| lr.decision_function(&x2))
         .unwrap();
@@ -1308,7 +1387,6 @@ mod tests {
             &SVCParameters::default()
                 .with_c(1.0)
                 .with_kernel(Kernels::rbf().with_gamma(0.7)),
-            None,
         )
         .and_then(|lr| lr.predict(&x))
         .unwrap();
@@ -1363,4 +1441,4 @@ mod tests {

         assert_eq!(svc, deserialized_svc);
     }
-}
\ No newline at end of file
+}

From 2e251c7c7314e5659c128289c11a7a2d3fdb956d Mon Sep 17 00:00:00 2001
From: bendeez
Date: Sun, 15 Jun 2025 13:39:44 -0500
Subject: [PATCH 3/8] fixed linting

---
 src/svm/svc.rs | 33 ++++++++++-----------------------
 1 file changed, 10 insertions(+), 23 deletions(-)

diff --git a/src/svm/svc.rs b/src/svm/svc.rs
index f5aa7611..1b0e2e0f 100644
--- a/src/svm/svc.rs
+++ b/src/svm/svc.rs
@@ -141,8 +141,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
     /// Predicts the class labels for new data points.
     ///
     /// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
-    /// It unwraps the inner `Result` from `MultiClassSVC::predict`, assuming that
-    /// the prediction will always succeed after a successful fit.
     ///
     /// # Arguments
     /// * `x` - A reference to the input features (2D array) for which to make predictions.
     ///
     /// # Returns
     /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
@@ -439,13 +437,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
     /// # Arguments
     /// * `x` - A reference to the input features (2D array) of the training data.
     /// * `y` - A reference to the target labels (1D array) of the training data.
-    /// * `parameters` - A reference to the `SVCParameters` controlling the training process
-    ///   (e.g., kernel, C-value, tolerance).
+    /// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance).
     /// * `multiclass_config` - A `MultiClassConfig` struct containing:
-    ///     - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC
-    ///       should distinguish.
-    ///     - `indices`: A `Vec` containing the indices of the data points in `x` and `y`
-    ///       that belong to either `class0` or `class1`.
+    ///     - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish.
+    ///     - `indices`: A `Vec` containing the indices of the data points in `x` and `y` that belong to either `class0` or `class1`.
     ///
     /// # Returns
     /// A `Result` which is:
     /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
- /// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, - /// or if the underlying optimization fails. + /// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails. pub fn fit( x: &'a X, y: &'a Y, @@ -509,18 +502,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array /// * `x` - A reference to the input features (2D array) of the training data. /// * `y` - A reference to the target labels (1D array) of the training data. /// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration. - /// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels - /// that the SVC will learn to separate. - /// * `indices` - An `Option>`. If `Some`, it contains the specific indices - /// of data points from `x` and `y` that should be used for training this - /// binary classifier. If `None`, all data points in `x` and `y` are considered. - /// + /// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate. + /// * `indices` - An `Option>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered. /// # Returns /// A `Result` which is: - /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model - /// components (support vectors, weights, bias). - /// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, - /// mismatched data shapes), or if the optimization process fails. + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias). + /// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails. fn optimize_and_fit( x: &'a X, y: &'a Y, @@ -550,7 +537,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array // Perform the optimization to find the support vectors, weight vector, and bias. // This is where the core SVM algorithm (e.g., SMO) would run. - let (support_vectors, weight, b) = optimizer.optimize(); + let (support_vectors, weight, b) = optimizer.optimize(); // Construct and return the fitted SVC model. Ok(SVC::<'a> { From 132ba7ef1d1be8489c6292cbc9506d3b28c7c4b2 Mon Sep 17 00:00:00 2001 From: bendeez Date: Sun, 15 Jun 2025 14:34:16 -0500 Subject: [PATCH 4/8] resolved issue --- src/svm/svc.rs | 163 +++++++++++++++++++++++++------------------------ 1 file changed, 82 insertions(+), 81 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 1b0e2e0f..d72ecdac 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -426,38 +426,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array1 + 'a> SVC<'a, TX, TY, X, Y> { - /// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios. - /// - /// This function is intended to be called by a multi-class strategy (e.g., one-vs-one) - /// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies - /// the two classes this SVC should discriminate and the subset of data indices - /// relevant to these classes. It then delegates the actual optimization and fitting - /// to `optimize_and_fit`. 
- /// - /// # Arguments - /// * `x` - A reference to the input features (2D array) of the training data. - /// * `y` - A reference to the target labels (1D array) of the training data. - /// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance). - /// * `multiclass_config` - A `MultiClassConfig` struct containing: - /// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish. - /// - `indices`: A `Vec` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.` - /// - /// # Returns - /// A `Result` which is: - /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. - /// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters). - fn multiclass_fit( - x: &'a X, - y: &'a Y, - parameters: &'a SVCParameters, - multiclass_config: MultiClassConfig, - ) -> Result, Failed> { - let classes = multiclass_config.classes; - let indices = multiclass_config.indices; - let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices)); - svc - } - /// Fits a binary Support Vector Classifier (SVC) to the provided data. /// /// This is the primary `fit` method for a standalone binary SVC. It expects @@ -492,6 +460,38 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array svc } + /// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios. + /// + /// This function is intended to be called by a multi-class strategy (e.g., one-vs-one) + /// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies + /// the two classes this SVC should discriminate and the subset of data indices + /// relevant to these classes. It then delegates the actual optimization and fitting + /// to `optimize_and_fit`. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance). + /// * `multiclass_config` - A `MultiClassConfig` struct containing: + /// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish. + /// - `indices`: A `Vec` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.` + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. + /// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters). + fn multiclass_fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + multiclass_config: MultiClassConfig, + ) -> Result, Failed> { + let classes = multiclass_config.classes; + let indices = multiclass_config.indices; + let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices)); + svc + } + /// Internal function to optimize and fit the Support Vector Classifier. /// /// This is the core logic for training a binary SVC. 
It performs several checks @@ -1194,55 +1194,6 @@ mod tests { use crate::metrics::accuracy; use crate::svm::Kernels; - #[cfg_attr( - all(target_arch = "wasm32", not(target_os = "wasi")), - wasm_bindgen_test::wasm_bindgen_test - )] - #[test] - fn svc_multiclass_fit_predict() { - let x = DenseMatrix::from_2d_array(&[ - &[5.1, 3.5, 1.4, 0.2], - &[4.9, 3.0, 1.4, 0.2], - &[4.7, 3.2, 1.3, 0.2], - &[4.6, 3.1, 1.5, 0.2], - &[5.0, 3.6, 1.4, 0.2], - &[5.4, 3.9, 1.7, 0.4], - &[4.6, 3.4, 1.4, 0.3], - &[5.0, 3.4, 1.5, 0.2], - &[4.4, 2.9, 1.4, 0.2], - &[4.9, 3.1, 1.5, 0.1], - &[7.0, 3.2, 4.7, 1.4], - &[6.4, 3.2, 4.5, 1.5], - &[6.9, 3.1, 4.9, 1.5], - &[5.5, 2.3, 4.0, 1.3], - &[6.5, 2.8, 4.6, 1.5], - &[5.7, 2.8, 4.5, 1.3], - &[6.3, 3.3, 4.7, 1.6], - &[4.9, 2.4, 3.3, 1.0], - &[6.6, 2.9, 4.6, 1.3], - &[5.2, 2.7, 3.9, 1.4], - ]) - .unwrap(); - - let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]; - - let knl = Kernels::linear(); - let parameters = SVCParameters::default() - .with_c(200.0) - .with_kernel(knl) - .with_seed(Some(100)); - - let y_hat = MultiClassSVC::fit(&x, &y, ¶meters) - .and_then(|lr| lr.predict(&x)) - .unwrap(); - - let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); - - assert!( - acc >= 0.9, - "Multiclass accuracy ({acc}) is not larger or equal to 0.9" - ); - } #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test @@ -1383,6 +1334,56 @@ mod tests { assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9"); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn svc_multiclass_fit_predict() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]) + .unwrap(); + + let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]; + + let knl = Kernels::linear(); + let parameters = SVCParameters::default() + .with_c(200.0) + .with_kernel(knl) + .with_seed(Some(100)); + + let y_hat = MultiClassSVC::fit(&x, &y, ¶meters) + .and_then(|lr| lr.predict(&x)) + .unwrap(); + + let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); + + assert!( + acc >= 0.9, + "Multiclass accuracy ({acc}) is not larger or equal to 0.9" + ); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test From 729dd5a4e2a14939444a9e7568fa90f86adb570d Mon Sep 17 00:00:00 2001 From: bendeez Date: Tue, 24 Jun 2025 13:02:42 -0500 Subject: [PATCH 5/8] implemented heirarchal clustering --- src/cluster/hierarchal_clustering.rs | 192 +++++++++++++++++++++++++++ src/cluster/mod.rs | 1 + src/linalg/basic/arrays.rs | 1 + 3 files changed, 194 insertions(+) create mode 100644 src/cluster/hierarchal_clustering.rs diff --git a/src/cluster/hierarchal_clustering.rs b/src/cluster/hierarchal_clustering.rs new file mode 100644 index 00000000..f6fe5635 --- /dev/null +++ b/src/cluster/hierarchal_clustering.rs @@ -0,0 +1,192 @@ +use crate::{ + error::Failed, + linalg::basic::arrays::{Array, 
Array1, Array2}, + metrics::distance::euclidian::Euclidian, + numbers::basenum::Number, +}; +use std::collections::HashMap; +use std::{f32, iter::zip, marker::PhantomData}; + +pub enum Linkage { + Ward, +} + +pub struct AgglomerativeClusteringParameters { + pub n_clusters: usize, + pub linkage: Linkage, +} + +impl AgglomerativeClusteringParameters { + pub fn with_n_clusters(mut self, n_clusters: usize) -> Self { + self.n_clusters = n_clusters; + self + } + + pub fn with_linkage(mut self, linkage: Linkage) -> Self { + self.linkage = linkage; + self + } +} + +pub struct AgglomerativeClustering, Y: Array1> { + pub labels: Vec, + _phantom_tx: PhantomData, + _phantom_ty: PhantomData, + _phantom_x: PhantomData, + _phantom_y: PhantomData, +} + +impl, Y: Array1> AgglomerativeClustering { + fn compute_cluster_variance( + data: &X, + cluster1_indices: &Vec, + cluster2_indices: &Vec, + ) -> f32 { + let (_, num_features) = data.shape(); + let mut sum_row = vec![0 as f32; num_features]; + for cluster in vec![cluster1_indices, cluster2_indices] { + for index in cluster { + sum_row = zip(sum_row, data.get_row(*index).iterator(0)) + .map(|(v, x)| v + x.to_f32().unwrap()) + .collect(); + } + } + let clusters_len = cluster1_indices.len() + cluster2_indices.len(); + let mean_row: Vec = sum_row.iter().map(|v| *v/clusters_len as f32).collect(); + let mut variance = 0.0; + for cluster in vec![cluster1_indices, cluster2_indices] { + for index in cluster { + let squared_distance: f32 = zip(data.get_row(*index).iterator(0), mean_row.iter()) + .map(|(x, v)| (x.to_f32().unwrap() - *v).powf(2.0)) + .sum(); + variance += squared_distance; + } + } + variance + } + + fn compute_distance<'a>( + data: &X, + linkage: &Linkage, + cache: &mut HashMap<&'a Vec, f32>, + cluster1_indices: &'a Vec, + cluster2_indices: &'a Vec, + ) -> f32 { + match linkage { + Linkage::Ward => { + let cluster1_variance = if let Some(variance) = cache.get(&cluster1_indices) { + *variance + } else { + let cluster1_variance = + Self::compute_cluster_variance(&data, &cluster1_indices, &vec![]); + cache.insert(&cluster1_indices, cluster1_variance); + cluster1_variance + }; + let cluster2_variance = if let Some(variance) = cache.get(&cluster2_indices) { + *variance + } else { + let cluster2_variance = + Self::compute_cluster_variance(&data, &cluster2_indices, &vec![]); + cache.insert(&cluster2_indices, cluster2_variance); + cluster2_variance + }; + let both_cluster_variance = cluster1_variance + cluster2_variance; + let distance = both_cluster_variance - cluster1_variance - cluster2_variance; + distance + } + } + } + pub fn fit( + data: &X, + parameters: AgglomerativeClusteringParameters, + ) -> Result, Failed> { + let mut cache = HashMap::new(); + let mut matrix = Vec::new(); + let (num_rows, _) = data.shape(); + let mut indices_mapping = HashMap::new(); + for i in 0..num_rows { + indices_mapping.insert(i, vec![i]); + } + for i in 0..num_rows { + let mut row = Vec::new(); + for j in i + 1..num_rows { + let distance = Self::compute_distance( + data, + ¶meters.linkage, + &mut cache, + indices_mapping.get(&i).unwrap(), + &indices_mapping.get(&j).unwrap(), + ); + row.push(distance); + } + matrix.push(row); + } + while indices_mapping.len() > parameters.n_clusters { + let mut min_distance = f32::INFINITY; + let mut pairs = (0, 0); + for (i, row) in matrix.iter().enumerate() { + if !indices_mapping.contains_key(&i) { + continue; + } + for (j, distance) in row.iter().enumerate() { + let j_offset = i + 1 + j; + if !indices_mapping.contains_key(&j_offset) { + 
continue;
+                }
+                if *distance < min_distance {
+                    min_distance = *distance;
+                    pairs = (i, j_offset);
+                }
+            }
+        }
+        let (i, j_offset) = pairs;
+        let cluster1_indices = indices_mapping.remove(&i).unwrap();
+        let cluster2_indices = indices_mapping.remove(&j_offset).unwrap();
+        cache.remove(&cluster1_indices);
+        cache.remove(&cluster2_indices);
+        let mut combined_cluster_indices = cluster1_indices;
+        combined_cluster_indices.extend(cluster2_indices);
+        indices_mapping.insert(i, combined_cluster_indices);
+        matrix[i] = (0..matrix[i].len())
+            .map(|j| {
+                if let Some(other_cluster_indices) = indices_mapping.get(i + 1 + j) {
+                    Self::compute_distance(
+                        &data,
+                        &parameters.linkage,
+                        &mut cache,
+                        &combined_cluster_indices,
+                        &other_cluster_indices,
+                    )
+                } else {
+                    0.0
+                }
+            })
+            .collect();
+        for g in 0..i {
+            let offset = i - g - 1;
+            if let Some(other_cluster_indices) = indices_mapping.get(&offset) {
+                matrix[g][offset] = Self::compute_distance(
+                    &data,
+                    &parameters.linkage,
+                    &mut cache,
+                    &combined_cluster_indices,
+                    &other_cluster_indices,
+                )
+            }
+        }
+        }
+        let mut labels = vec![0; num_rows];
+        for (i, cluster) in indices_mapping.keys().enumerate() {
+            for index in indices_mapping[cluster] {
+                labels[index] = i
+            }
+        }
+        Ok(Self {
+            labels,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+            _phantom_x: PhantomData,
+            _phantom_y: PhantomData,
+        })
+    }
+}
diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs
index be6ef9f0..5b7f5535 100644
--- a/src/cluster/mod.rs
+++ b/src/cluster/mod.rs
@@ -4,5 +4,6 @@
 //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
 pub mod dbscan;
+pub mod hierarchal_clustering;
 /// An iterative clustering algorithm that aims to find local maxima in each iteration.
 pub mod kmeans;
diff --git a/src/linalg/basic/arrays.rs b/src/linalg/basic/arrays.rs
index a5abe634..03c27906 100644
--- a/src/linalg/basic/arrays.rs
+++ b/src/linalg/basic/arrays.rs
@@ -170,6 +170,7 @@ pub trait ArrayView1<T: Debug + Display + Copy + Sized>: Array<T, usize> {
             .map(|(s, o)| *s * *o)
             .sum()
     }
+
     /// return sum of all value of the view
     fn sum(&self) -> T
     where

From 33b1d01ffdf2ad2743752bd993f4e661f77b3fc8 Mon Sep 17 00:00:00 2001
From: bendeez
Date: Tue, 24 Jun 2025 13:49:06 -0500
Subject: [PATCH 6/8] implemented hierarchical clustering

---
 ...hierarchal_clustering.rs => hierarchal.rs} | 36 ++++++++++---------
 src/cluster/mod.rs                            |  2 +-
 2 files changed, 20 insertions(+), 18 deletions(-)
 rename src/cluster/{hierarchal_clustering.rs => hierarchal.rs} (87%)

diff --git a/src/cluster/hierarchal_clustering.rs b/src/cluster/hierarchal.rs
similarity index 87%
rename from src/cluster/hierarchal_clustering.rs
rename to src/cluster/hierarchal.rs
index f6fe5635..dc284b56 100644
--- a/src/cluster/hierarchal_clustering.rs
+++ b/src/cluster/hierarchal.rs
@@ -1,7 +1,6 @@
 use crate::{
     error::Failed,
-    linalg::basic::arrays::{Array, Array1, Array2},
-    metrics::distance::euclidian::Euclidian,
+    linalg::basic::arrays::{Array1, Array2},
     numbers::basenum::Number,
 };
 use std::collections::HashMap;
 use std::{f32, iter::zip, marker::PhantomData};
@@ -68,29 +67,29 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
     fn compute_distance<'a>(
         data: &X,
         linkage: &Linkage,
-        cache: &mut HashMap<&'a Vec<usize>, f32>,
-        cluster1_indices: &'a Vec<usize>,
-        cluster2_indices: &'a Vec<usize>,
+        cache: &mut HashMap<Vec<usize>, f32>,
+        cluster1_indices: &Vec<usize>,
+        cluster2_indices: &Vec<usize>,
     ) -> f32 {
         match linkage {
             Linkage::Ward => {
-                let cluster1_variance = if let Some(variance) = cache.get(&cluster1_indices) {
+                let cluster1_variance = if let Some(variance) = cache.get(cluster1_indices) {
                     *variance
                 } else {
                     let cluster1_variance =
                         Self::compute_cluster_variance(&data, &cluster1_indices, &vec![]);
-                    cache.insert(&cluster1_indices, cluster1_variance);
+                    cache.insert(cluster1_indices.clone(), cluster1_variance);
                     cluster1_variance
                 };
-                let cluster2_variance = if let Some(variance) = cache.get(&cluster2_indices) {
+                let cluster2_variance = if let Some(variance) = cache.get(cluster2_indices) {
                     *variance
                 } else {
                     let cluster2_variance =
                         Self::compute_cluster_variance(&data, &cluster2_indices, &vec![]);
-                    cache.insert(&cluster2_indices, cluster2_variance);
+                    cache.insert(cluster2_indices.clone(), cluster2_variance);
                     cluster2_variance
                 };
-                let both_cluster_variance = cluster1_variance + cluster2_variance;
+                let both_cluster_variance = Self::compute_cluster_variance(&data, &cluster1_indices, &cluster2_indices);
                 let distance = both_cluster_variance - cluster1_variance - cluster2_variance;
                 distance
             }
         }
     }
@@ -115,7 +114,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                     &parameters.linkage,
                     &mut cache,
                     indices_mapping.get(&i).unwrap(),
-                    &indices_mapping.get(&j).unwrap(),
+                    indices_mapping.get(&j).unwrap(),
                 );
                 row.push(distance);
             }
@@ -146,10 +145,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
             cache.remove(&cluster2_indices);
             let mut combined_cluster_indices = cluster1_indices;
             combined_cluster_indices.extend(cluster2_indices);
-            indices_mapping.insert(i, combined_cluster_indices);
             matrix[i] = (0..matrix[i].len())
                 .map(|j| {
-                    if let Some(other_cluster_indices) = indices_mapping.get(i + 1 + j) {
+                    let j_offset = i + 1 + j;
+                    if let Some(other_cluster_indices) = indices_mapping.get(&j_offset) {
                         Self::compute_distance(
                             &data,
                             &parameters.linkage,
                             &mut cache,
                             &combined_cluster_indices,
                             &other_cluster_indices,
                         )
                     } else {
                         0.0
                     }
                 })
                 .collect();
             for g in 0..i {
                 let offset = i - g - 1;
-                if let Some(other_cluster_indices) = indices_mapping.get(&offset) {
+                if let Some(other_cluster_indices) = indices_mapping.get(&g) {
                     matrix[g][offset] = Self::compute_distance(
                         &data,
                         &parameters.linkage,
                         &mut cache,
                         &combined_cluster_indices,
                         &other_cluster_indices,
                     )
                 }
             }
+            indices_mapping.insert(i, combined_cluster_indices);
         }
         let mut labels = vec![0; num_rows];
-        for (i, cluster) in indices_mapping.keys().enumerate() {
-            for index in indices_mapping[cluster] {
-                labels[index] = i
+        let mut sorted_keys: Vec<&usize> = indices_mapping.keys().collect();
+        sorted_keys.sort();
+        for (i, cluster) in sorted_keys.iter().enumerate() {
+            for index in indices_mapping.get(cluster).unwrap() {
+                labels[*index] = i
             }
         }
         Ok(Self {
diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs
index 5b7f5535..c32a6366 100644
--- a/src/cluster/mod.rs
+++ b/src/cluster/mod.rs
@@ -4,6 +4,6 @@
 //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
 pub mod dbscan;
-pub mod hierarchal_clustering;
 /// An iterative clustering algorithm that aims to find local maxima in each iteration.
 pub mod kmeans;
+pub mod hierarchal;

From d84620ee9116f91946458ea3aa303f763b362786 Mon Sep 17 00:00:00 2001
From: bendeez
Date: Tue, 24 Jun 2025 16:10:30 -0500
Subject: [PATCH 7/8] added tests

---
 src/cluster/hierarchal.rs | 442 ++++++++++++++++++++++++++++++++++++--
 src/cluster/kmeans.rs     |   2 +
 src/cluster/mod.rs        |   2 +-
 3 files changed, 427 insertions(+), 19 deletions(-)

diff --git a/src/cluster/hierarchal.rs b/src/cluster/hierarchal.rs
index dc284b56..8cad1a18 100644
--- a/src/cluster/hierarchal.rs
+++ b/src/cluster/hierarchal.rs
@@ -1,62 +1,172 @@
+//! # Hierarchical Clustering
+//!
+//! Hierarchical clustering is a method of cluster analysis that builds a hierarchy of clusters, either from the bottom up or the top down. Unlike partitioning algorithms such as K-Means, it does not require the number of clusters to be specified beforehand. Instead, it produces a tree-like structure called a dendrogram that illustrates the nested grouping of data points. A desired number of clusters can then be obtained by "cutting" the dendrogram at a specific level.
+//!
+//! This implementation uses the agglomerative (bottom-up) approach, which is the most common strategy for hierarchical clustering.
+//!
+//! The agglomerative algorithm works as follows:
+//!
+//! 1. Initialization: Each data point starts in its own individual cluster.
+//! 2. Iterative Merging: In each step, the two closest clusters are identified and merged into a single new cluster.
+//! 3. Termination: This process is repeated until all data points are contained within a single, all-encompassing cluster, thus completing the hierarchy.
+//!
+//! A critical choice in this process is the linkage criterion, which defines how the distance between two clusters is measured. This choice significantly influences the shape of the clusters and the structure of the dendrogram. This implementation uses Ward's Linkage, which minimizes the increase in the total within-cluster variance when merging clusters. It is particularly effective at identifying compact, spherical clusters.
+//!
+//! Example:
+//!
+//! use smartcore::linalg::basic::matrix::DenseMatrix;
+//! use smartcore::cluster::hierarchal::{AgglomerativeClustering, AgglomerativeClusteringParameters, Linkage};
+//! let x = DenseMatrix::from_2d_array(&[
+//! &[5.1, 3.5, 1.4, 0.2],
+//! &[4.9, 3.0, 1.4, 0.2],
+//!
&[4.7, 3.2, 1.3, 0.2], +//! &[4.6, 3.1, 1.5, 0.2], +//! &[5.0, 3.6, 1.4, 0.2], +//! &[5.4, 3.9, 1.7, 0.4], +//! &[4.6, 3.4, 1.4, 0.3], +//! &[5.0, 3.4, 1.5, 0.2], +//! &[4.4, 2.9, 1.4, 0.2], +//! &[4.9, 3.1, 1.5, 0.1], +//! &[7.0, 3.2, 4.7, 1.4], +//! &[6.4, 3.2, 4.5, 1.5], +//! &[6.9, 3.1, 4.9, 1.5], +//! &[5.5, 2.3, 4.0, 1.3], +//! &[6.5, 2.8, 4.6, 1.5], +//! &[5.7, 2.8, 4.5, 1.3], +//! &[6.3, 3.3, 4.7, 1.6], +//! &[4.9, 2.4, 3.3, 1.0], +//! &[6.6, 2.9, 4.6, 1.3], +//! &[5.2, 2.7, 3.9, 1.4], +//! &[6.3, 2.5, 5.0, 1.9], +//! &[6.5, 3.0, 5.2, 2.0], +//! &[6.2, 3.4, 5.4, 2.3], +//! &[5.9, 3.0, 5.1, 1.8], +//! ]).unwrap(); +//! let params = AgglomerativeClusteringParameters { +//! n_clusters: 3, +//! linkage: Linkage::Ward, +//! }; +//! let clustering_result = AgglomerativeClustering::, Vec>::fit(&x, params).unwrap(); +//! let y_hat = clustering_result.labels; +//! ## References: +//! +//! * "An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 10 +//! * "Hierarchical Grouping to Optimize an Objective Function", Ward, J. H., Jr., 1963 +//! * "Finding Groups in Data: An Introduction to Cluster Analysis", Kaufman, L., Rousseeuw, P.J., 1990 use crate::{ error::Failed, linalg::basic::arrays::{Array1, Array2}, numbers::basenum::Number, }; +use crate::api::{UnsupervisedEstimator}; use std::collections::HashMap; use std::{f32, iter::zip, marker::PhantomData}; +use std::collections::HashSet; +/// Defines the linkage criterion to use for Agglomerative Clustering. +/// +/// The linkage criterion determines which distance to use between sets of observations. +/// The algorithm will merge the pairs of clusters that minimize this criterion. pub enum Linkage { + /// Ward's minimum variance method. + /// + /// Ward's method minimizes the sum of squared differences within all clusters. + /// It is a variance-minimizing approach and in this sense is similar to the k-means + /// objective function but tackled with an agglomerative hierarchical approach. Ward, } +/// Parameters for the Agglomerative Clustering algorithm. +/// +/// This struct is used to configure the clustering process. It can be instantiated +/// and then modified using a builder pattern. pub struct AgglomerativeClusteringParameters { + /// The number of clusters to find. pub n_clusters: usize, + /// The linkage criterion to use. pub linkage: Linkage, } impl AgglomerativeClusteringParameters { + /// Sets the number of clusters. + /// + /// # Arguments + /// + /// * `n_clusters` - The desired number of clusters. pub fn with_n_clusters(mut self, n_clusters: usize) -> Self { self.n_clusters = n_clusters; self } + /// Sets the linkage criterion. + /// + /// # Arguments + /// + /// * `linkage` - The linkage method to use for clustering. pub fn with_linkage(mut self, linkage: Linkage) -> Self { self.linkage = linkage; self } } +/// Represents the result of an Agglomerative Clustering operation. +/// +/// This struct holds the cluster labels assigned to each sample in the input data. pub struct AgglomerativeClustering, Y: Array1> { + /// A vector where `labels[i]` is the cluster identifier for the i-th sample. pub labels: Vec, + /// Phantom data to hold the generic type `TX`. _phantom_tx: PhantomData, + /// Phantom data to hold the generic type `TY`. _phantom_ty: PhantomData, + /// Phantom data to hold the generic type `X`. _phantom_x: PhantomData, + /// Phantom data to hold the generic type `Y`. 
_phantom_y: PhantomData, } impl, Y: Array1> AgglomerativeClustering { + /// Computes the variance of a potential cluster. + /// + /// This function calculates the sum of squared distances from each point in the + /// combined cluster to the cluster's mean. This is a key component of Ward's linkage. + /// + /// # Arguments + /// + /// * `data` - The input data matrix. + /// * `cluster1_indices` - Indices of the data points in the first cluster. + /// * `cluster2_indices` - Indices of the data points in the second cluster (can be empty). + /// + /// # Returns + /// + /// The variance of the combined cluster as an `f32`. fn compute_cluster_variance( data: &X, cluster1_indices: &Vec, cluster2_indices: &Vec, - ) -> f32 { + ) -> f64 { let (_, num_features) = data.shape(); - let mut sum_row = vec![0 as f32; num_features]; + let mut sum_row = vec![0 as f64; num_features]; + + // Sum up all feature vectors for the points in the given clusters for cluster in vec![cluster1_indices, cluster2_indices] { for index in cluster { sum_row = zip(sum_row, data.get_row(*index).iterator(0)) - .map(|(v, x)| v + x.to_f32().unwrap()) + .map(|(v, x)| v + x.to_f64().unwrap()) .collect(); } } + let clusters_len = cluster1_indices.len() + cluster2_indices.len(); - let mean_row: Vec = sum_row.iter().map(|v| *v/clusters_len as f32).collect(); + // Calculate the mean of the combined cluster + let mean_row: Vec = sum_row.iter().map(|v| *v / clusters_len as f64).collect(); + let mut variance = 0.0; + // Calculate the sum of squared distances from each point to the mean for cluster in vec![cluster1_indices, cluster2_indices] { for index in cluster { - let squared_distance: f32 = zip(data.get_row(*index).iterator(0), mean_row.iter()) - .map(|(x, v)| (x.to_f32().unwrap() - *v).powf(2.0)) + let squared_distance: f64 = zip(data.get_row(*index).iterator(0), mean_row.iter()) + .map(|(x, v)| (x.to_f64().unwrap() - *v).powf(2.0)) .sum(); variance += squared_distance; } @@ -64,15 +174,33 @@ impl, Y: Array1> AgglomerativeClusteri variance } + /// Computes the distance between two clusters based on the specified linkage. + /// + /// # Arguments + /// + /// * `data` - The input data matrix. + /// * `linkage` - The linkage criterion to use. + /// * `cache` - A mutable HashMap to store and retrieve pre-computed cluster variances for performance. + /// * `cluster1_indices` - Indices of the data points in the first cluster. + /// * `cluster2_indices` - Indices of the data points in the second cluster. + /// + /// # Returns + /// + /// The distance between the two clusters as an `f32`. fn compute_distance<'a>( data: &X, linkage: &Linkage, - cache: &mut HashMap, f32>, + cache: &mut HashMap, f64>, cluster1_indices: &Vec, cluster2_indices: &Vec, - ) -> f32 { + ) -> f64 { match linkage { Linkage::Ward => { + // For Ward's method, the distance is the increase in variance that would result + // from merging the two clusters. 
+
+                // Get variance of the first cluster, from cache or by computing it
                 let cluster1_variance = if let Some(variance) = cache.get(cluster1_indices) {
                     *variance
                 } else {
@@ -81,6 +209,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                     cache.insert(cluster1_indices.clone(), cluster1_variance);
                     cluster1_variance
                 };
+
+                // Get variance of the second cluster, from cache or by computing it
                 let cluster2_variance = if let Some(variance) = cache.get(cluster2_indices) {
                     *variance
                 } else {
@@ -89,12 +219,38 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                     cache.insert(cluster2_indices.clone(), cluster2_variance);
                     cluster2_variance
                 };
-                let both_cluster_variance = Self::compute_cluster_variance(&data, &cluster1_indices, &cluster2_indices);
+
+                // Compute variance of the merged cluster
+                let both_cluster_variance =
+                    Self::compute_cluster_variance(&data, &cluster1_indices, &cluster2_indices);
+
+                // The increase in variance is the distance
                 let distance = both_cluster_variance - cluster1_variance - cluster2_variance;
                 distance
             }
         }
     }
+
+    /// Fit the agglomerative clustering model to the data.
+    ///
+    /// This method performs hierarchical clustering using a bottom-up approach. Each observation
+    /// starts in its own cluster, and clusters are successively merged together. The process
+    /// continues until the desired number of clusters is reached.
+    ///
+    /// # Arguments
+    ///
+    /// * `data` - A 2D array-like structure of shape (n_samples, n_features).
+    /// * `parameters` - The parameters for the clustering algorithm, including `n_clusters` and `linkage`.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` which is `Ok` containing an `AgglomerativeClustering` instance with the
+    /// final cluster labels, or an `Err` with a `Failed` error type if something goes wrong.
+    ///
+
+    /// let clustering_result = AgglomerativeClustering::fit(&data, params).unwrap();
+    /// // `clustering_result.labels` will contain the cluster assignment for each row of data.
+    /// ```
     pub fn fit(
         data: &X,
         parameters: AgglomerativeClusteringParameters,
@@ -102,10 +258,16 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
         let mut cache = HashMap::new();
         let mut matrix = Vec::new();
         let (num_rows, _) = data.shape();
+
+        // Initially, each data point is its own cluster.
+        // `indices_mapping` maps a cluster ID to the list of original data point indices it contains.
         let mut indices_mapping = HashMap::new();
         for i in 0..num_rows {
             indices_mapping.insert(i, vec![i]);
         }
+
+        // Pre-compute the initial distance matrix for all pairs of points.
+        // This is an upper triangular matrix to save space.
         for i in 0..num_rows {
             let mut row = Vec::new();
             for j in i + 1..num_rows {
@@ -120,17 +282,21 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
             }
             matrix.push(row);
         }
+
+        // Iteratively merge clusters until `n_clusters` is reached.
         while indices_mapping.len() > parameters.n_clusters {
-            let mut min_distance = f32::INFINITY;
+            let mut min_distance = f64::INFINITY;
             let mut pairs = (0, 0);
+
+            // Find the two closest clusters.
             for (i, row) in matrix.iter().enumerate() {
                 if !indices_mapping.contains_key(&i) {
-                    continue;
+                    continue; // Skip clusters that have been merged.
                 }
                 for (j, distance) in row.iter().enumerate() {
-                    let j_offset = i + 1 + j;
+                    let j_offset = i + 1 + j; // Get the real index for the second cluster.
                     if !indices_mapping.contains_key(&j_offset) {
-                        continue;
+                        continue; // Skip clusters that have been merged.
                     }
                     if *distance < min_distance {
                         min_distance = *distance;
@@ -138,13 +304,20 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                 }
             }
+
            let (i, j_offset) = pairs;
+
+           // Merge the two closest clusters (`i` and `j_offset`).
            let cluster1_indices = indices_mapping.remove(&i).unwrap();
            let cluster2_indices = indices_mapping.remove(&j_offset).unwrap();
-           cache.remove(&cluster1_indices);
+           cache.remove(&cluster1_indices); // Clear old cache entries.
            cache.remove(&cluster2_indices);
+
            let mut combined_cluster_indices = cluster1_indices;
            combined_cluster_indices.extend(cluster2_indices);
+
+           // Update the distance matrix. The new merged cluster will be stored at index `i`.
+           // Update distances from the new cluster `i` to all other clusters `j` where `j > i`.
            matrix[i] = (0..matrix[i].len())
                .map(|j| {
                    let j_offset = i + 1 + j;
@@ -157,10 +330,12 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                        &other_cluster_indices,
                    )
                } else {
-                   0.0
+                   0.0 // This entry is now invalid as the other cluster was merged.
                }
                })
                .collect();
+
+           // Update distances from all other clusters `g` to the new cluster `i` where `g < i`.
            for g in 0..i {
                let offset = i - g - 1;
                if let Some(other_cluster_indices) = indices_mapping.get(&g) {
                    matrix[g][offset] = Self::compute_distance(
                        &data,
                        &parameters.linkage,
                        &mut cache,
-                       &combined_cluster_indices,
+                       &combined_cluster_indices, // Order does not matter for Ward's method.
                        &other_cluster_indices,
                    )
                }
            }
+           // Add the new merged cluster to the mapping.
            indices_mapping.insert(i, combined_cluster_indices);
        }
+
+       // Assign final labels based on the remaining clusters.
        let mut labels = vec![0; num_rows];
        let mut sorted_keys: Vec<&usize> = indices_mapping.keys().collect();
-       sorted_keys.sort();
+       sorted_keys.sort(); // Sort for consistent label assignment.
        for (i, cluster) in sorted_keys.iter().enumerate() {
            for index in indices_mapping.get(cluster).unwrap() {
-               labels[*index] = i
+               labels[*index] = i;
            }
        }
+
        Ok(Self {
            labels,
            _phantom_tx: PhantomData,
@@ -192,3 +371,230 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
        })
    }
 }
+
+impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    UnsupervisedEstimator<X, AgglomerativeClusteringParameters> for AgglomerativeClustering<TX, TY, X, Y>
+{
+    fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
+        AgglomerativeClustering::fit(x, parameters)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::basic::matrix::DenseMatrix;
+    use std::collections::HashSet;
+
+    fn assert_approx_eq(a: f32, b: f32) {
+        assert!(
+            (a - b).abs() < 1e-6,
+            "assertion failed: `(left != right)` \n left: `{:?}`\n right: `{:?}`",
+            a,
+            b
+        );
+    }
+
+    #[test]
+    fn test_compute_cluster_variance() {
+        let data = DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[3.0, 3.0], &[5.0, 5.0]]).unwrap();
+
+        // Variance of a single point is 0
+        let variance1 =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+                &data,
+                &vec![0],
+                &vec![],
+            );
+        assert_approx_eq(variance1, 0.0);
+
+        // Variance of two points: [1,1] and [3,3]
+        // Mean is [2,2]
+        // Variance = ((1-2)^2 + (1-2)^2) + ((3-2)^2 + (3-2)^2) = (1+1) + (1+1) = 4.0
+        let variance2 =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+                &data,
+                &vec![0],
+                &vec![1],
+            );
+        assert_approx_eq(variance2, 4.0);
+
+        // Variance of three points: [1,1], [3,3], [5,5]
+        // Mean is [3,3]
+        // Variance = ((1-3)^2+(1-3)^2) + ((3-3)^2+(3-3)^2) + ((5-3)^2+(5-3)^2)
+        //          = (4+4) + (0+0) + (4+4) = 16.0
+        let variance3 =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+                &data,
+                &vec![0, 1, 2],
+                &vec![],
+            );
+        assert_approx_eq(variance3, 16.0);
+    }
+
+    #[test]
+    fn test_compute_distance_ward() {
+        let data = DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[3.0, 3.0]]).unwrap();
+        let mut cache = HashMap::new();
+
+        let cluster1_indices = vec![0];
+        let cluster2_indices = vec![1];
+
+        // var(c1) = 0, var(c2) = 0
+        // var(c1 U c2) = 4.0 (from test above)
+        // distance = 4.0 - 0 - 0 = 4.0
+        let distance =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_distance(
+                &data,
+                &Linkage::Ward,
+                &mut cache,
+                &cluster1_indices,
+                &cluster2_indices,
+            );
+
+        assert_approx_eq(distance, 4.0);
+        // check that cache was populated
+        assert!(cache.contains_key(&cluster1_indices));
+        assert!(cache.contains_key(&cluster2_indices));
+    }
+
+    #[test]
+    fn test_fit_simple_clusters() {
+        let data = DenseMatrix::from_2d_array(&[
+            &[1.0, 2.0],  // cluster 0
+            &[1.5, 1.8],  // cluster 0
+            &[1.0, 0.6],  // cluster 0
+            &[8.0, 8.0],  // cluster 1
+            &[9.0, 11.0], // cluster 1
+            &[8.5, 9.5],  // cluster 1
+        ])
+        .unwrap();
+
+        let params = AgglomerativeClusteringParameters {
+            n_clusters: 2,
+            linkage: Linkage::Ward,
+        };
+
+        let result =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params)
+                .unwrap();
+        let labels = result.labels;
+
+        assert_eq!(labels.len(), 6);
+
+        let label_set_1 = labels[0];
+        let label_set_2 = labels[3];
+
+        // Assert the two sets have different labels
+        assert_ne!(label_set_1, label_set_2);
+
+        // Assert that the first three points belong to the same cluster
+        assert_eq!(labels[0], label_set_1);
+        assert_eq!(labels[1], label_set_1);
+        assert_eq!(labels[2], label_set_1);
+
+        // Assert that the last three points belong to the same cluster
+        assert_eq!(labels[3], label_set_2);
+        assert_eq!(labels[4], label_set_2);
+        assert_eq!(labels[5], label_set_2);
+    }
+
+    #[test]
+    fn test_n_clusters_parameter() {
+        let data =
+            DenseMatrix::from_2d_array(&[&[0.0], &[1.0], &[10.0], &[11.0], &[20.0], &[21.0]])
+                .unwrap();
+
+        // Test with n_clusters = 3
+        let params_3 = AgglomerativeClusteringParameters {
+            n_clusters: 3,
+            linkage: Linkage::Ward,
+        };
+        let result_3 =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params_3)
+                .unwrap();
+        let unique_labels_3: HashSet<usize> = result_3.labels.into_iter().collect();
+        assert_eq!(unique_labels_3.len(), 3);
+
+        // Test with n_clusters = 1
+        let params_1 = AgglomerativeClusteringParameters {
+            n_clusters: 1,
+            linkage: Linkage::Ward,
+        };
+        let result_1 =
+            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params_1)
+                .unwrap();
+        let unique_labels_1: HashSet<usize> = result_1.labels.into_iter().collect();
+        assert_eq!(unique_labels_1.len(), 1);
+    }
+
+    #[test]
+fn test_fit_heavy_load_deterministic() {
+    let n_clusters = 5;
+
+    // Define cluster properties: (center_x, center_y, num_points)
+    let cluster_definitions = vec![
+        (0.0, 0.0, 10),
+        (100.0, 0.0, 20),
+        (0.0, 100.0, 15),
+        (100.0, 100.0, 25),
+        (50.0, -50.0, 5),
+    ];
+
+    // The expected sizes of the final clusters.
+    let mut expected_counts: Vec<usize> =
+        cluster_definitions.iter().map(|c| c.2).collect();
+    expected_counts.sort_unstable();
+
+    let mut data_vec: Vec<Vec<f32>> = Vec::new();
+
+    // Generate data points for each cluster deterministically.
+    for (center_x, center_y, num_points) in cluster_definitions {
+        for i in 0..num_points {
+            // Add a small, predictable offset to each point based on its index.
+            // This creates a small, non-random spread around the center.
+            let offset = i as f32 * 0.1;
+            let x = center_x + offset;
+            let y = center_y + offset;
+            data_vec.push(vec![x, y]);
+        }
+    }
+
+    // Convert to DenseMatrix
+    let data_refs: Vec<&[f32]> = data_vec.iter().map(|row| row.as_slice()).collect();
+    let data = DenseMatrix::from_2d_array(&data_refs).unwrap();
+
+    // Run clustering
+    let params = AgglomerativeClusteringParameters {
+        n_clusters,
+        linkage: Linkage::Ward,
+    };
+    let result = AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params).unwrap();
+    let labels = result.labels;
+
+    // 1. Verify the number of distinct clusters found
+    let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
+    assert_eq!(
+        unique_labels.len(),
+        n_clusters,
+        "Expected {} distinct clusters, but found {}",
+        n_clusters,
+        unique_labels.len()
+    );
+
+    // 2. Verify the number of members in each cluster
+    let mut label_counts: HashMap<usize, usize> = HashMap::new();
+    for label in labels {
+        *label_counts.entry(label).or_insert(0) += 1;
+    }
+
+    let mut actual_counts: Vec<usize> = label_counts.values().cloned().collect();
+    actual_counts.sort_unstable();
+
+    assert_eq!(
+        actual_counts, expected_counts,
+        "Cluster sizes do not match expected values"
+    );
+}
+
+}
diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs
index 2fade68f..76da5a39 100644
--- a/src/cluster/kmeans.rs
+++ b/src/cluster/kmeans.rs
@@ -413,6 +413,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
     }
 }

+
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs
index c32a6366..e5ad188e 100644
--- a/src/cluster/mod.rs
+++ b/src/cluster/mod.rs
@@ -4,6 +4,6 @@
 //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
 pub mod dbscan;
+pub mod hierarchal;
 /// An iterative clustering algorithm that aims to find local maxima in each iteration.
 pub mod kmeans;
-pub mod hierarchal;
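For reference, at this point in the series `AgglomerativeClustering` is usable through its inherent `fit`. Below is a minimal, hedged sketch of that usage — the `smartcore::cluster::hierarchal` path is inferred from the `mod.rs` change above, and the data and expected grouping mirror `test_fit_simple_clusters`:

```rust
use smartcore::cluster::hierarchal::{
    AgglomerativeClustering, AgglomerativeClusteringParameters, Linkage,
};
use smartcore::linalg::basic::matrix::DenseMatrix;

fn main() {
    // Two well-separated groups, as in `test_fit_simple_clusters`.
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[1.5, 1.8],
        &[1.0, 0.6],
        &[8.0, 8.0],
        &[9.0, 11.0],
        &[8.5, 9.5],
    ])
    .unwrap();

    let params = AgglomerativeClusteringParameters {
        n_clusters: 2,
        linkage: Linkage::Ward,
    };

    // Type parameters follow the module's doc example: f32 features, usize labels.
    let result =
        AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&x, params)
            .unwrap();

    // Points 0-2 should share one label and points 3-5 the other.
    assert_eq!(result.labels.len(), 6);
    assert_eq!(result.labels[0], result.labels[1]);
    assert_ne!(result.labels[0], result.labels[3]);
}
```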
From 14071a5aeb46f419e62be29e9c069193e9663f9f Mon Sep 17 00:00:00 2001
From: bendeez
Date: Tue, 24 Jun 2025 19:55:01 -0500
Subject: [PATCH 8/8] implemented hierarchal_clustering

---
 src/cluster/hierarchal.rs | 189 +++++++++++++++++++-------------------
 src/cluster/kmeans.rs     |   2 -
 2 files changed, 92 insertions(+), 99 deletions(-)

diff --git a/src/cluster/hierarchal.rs b/src/cluster/hierarchal.rs
index 8cad1a18..f347bb3f 100644
--- a/src/cluster/hierarchal.rs
+++ b/src/cluster/hierarchal.rs
@@ -53,15 +53,14 @@
 //! * "An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 10
 //! * "Hierarchical Grouping to Optimize an Objective Function", Ward, J. H., Jr., 1963
 //! * "Finding Groups in Data: An Introduction to Cluster Analysis", Kaufman, L., Rousseeuw, P.J., 1990
+use crate::api::UnsupervisedEstimator;
 use crate::{
     error::Failed,
     linalg::basic::arrays::{Array1, Array2},
     numbers::basenum::Number,
 };
-use crate::api::{UnsupervisedEstimator};
 use std::collections::HashMap;
-use std::{f32, iter::zip, marker::PhantomData};
-use std::collections::HashSet;
+use std::{f64, iter::zip, marker::PhantomData};

 /// Defines the linkage criterion to use for Agglomerative Clustering.
 ///
@@ -139,7 +138,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
     ///
     /// # Returns
     ///
-    /// The variance of the combined cluster as an `f32`.
+    /// The variance of the combined cluster as an `f64`.
     fn compute_cluster_variance(
         data: &X,
         cluster1_indices: &Vec<usize>,
@@ -149,7 +148,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
         let mut sum_row = vec![0 as f64; num_features];

         // Sum up all feature vectors for the points in the given clusters
-        for cluster in vec![cluster1_indices, cluster2_indices] {
+        for cluster in [cluster1_indices, cluster2_indices] {
             for index in cluster {
                 sum_row = zip(sum_row, data.get_row(*index).iterator(0))
                     .map(|(v, x)| v + x.to_f64().unwrap())
@@ -163,11 +162,11 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri

         let mut variance = 0.0;
         // Calculate the sum of squared distances from each point to the mean
-        for cluster in vec![cluster1_indices, cluster2_indices] {
+        for cluster in [cluster1_indices, cluster2_indices] {
             for index in cluster {
                 let squared_distance: f64 = zip(data.get_row(*index).iterator(0), mean_row.iter())
                     .map(|(x, v)| (x.to_f64().unwrap() - *v).powf(2.0))
-                    .sum();
+                    .sum::<f64>();
                 variance += squared_distance;
             }
         }
@@ -186,8 +185,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
     ///
     /// # Returns
     ///
-    /// The distance between the two clusters as an `f32`.
+    /// The distance between the two clusters as an `f64`.
-    fn compute_distance<'a>(
+    fn compute_distance(
         data: &X,
         linkage: &Linkage,
         cache: &mut HashMap<Vec<usize>, f64>,
         cluster1_indices: &Vec<usize>,
         cluster2_indices: &Vec<usize>,
@@ -205,7 +204,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                     *variance
                 } else {
                     let cluster1_variance =
-                        Self::compute_cluster_variance(&data, &cluster1_indices, &vec![]);
+                        Self::compute_cluster_variance(data, cluster1_indices, &vec![]);
                     cache.insert(cluster1_indices.clone(), cluster1_variance);
                     cluster1_variance
                 };
@@ -215,18 +214,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                     *variance
                 } else {
                     let cluster2_variance =
-                        Self::compute_cluster_variance(&data, &cluster2_indices, &vec![]);
+                        Self::compute_cluster_variance(data, cluster2_indices, &vec![]);
                     cache.insert(cluster2_indices.clone(), cluster2_variance);
                     cluster2_variance
                 };

                 // Compute variance of the merged cluster
                 let both_cluster_variance =
-                    Self::compute_cluster_variance(&data, &cluster1_indices, &cluster2_indices);
+                    Self::compute_cluster_variance(data, cluster1_indices, cluster2_indices);

                 // The increase in variance is the distance
-                let distance = both_cluster_variance - cluster1_variance - cluster2_variance;
-                distance
+                both_cluster_variance - cluster1_variance - cluster2_variance
             }
         }
     }
@@ -246,11 +244,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
     ///
     /// A `Result` which is `Ok` containing an `AgglomerativeClustering` instance with the
     /// final cluster labels, or an `Err` with a `Failed` error type if something goes wrong.
-    ///
-
-    /// let clustering_result = AgglomerativeClustering::fit(&data, params).unwrap();
-    /// // `clustering_result.labels` will contain the cluster assignment for each row of data.
-    /// ```
     pub fn fit(
         data: &X,
         parameters: AgglomerativeClusteringParameters,
@@ -323,11 +316,11 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
                 let j_offset = i + 1 + j;
                 if let Some(other_cluster_indices) = indices_mapping.get(&j_offset) {
                     Self::compute_distance(
-                        &data,
+                        data,
                         &parameters.linkage,
                         &mut cache,
                         &combined_cluster_indices,
-                        &other_cluster_indices,
+                        other_cluster_indices,
                     )
                 } else {
                     0.0 // This entry is now invalid as the other cluster was merged.
@@ -335,16 +328,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
             })
             .collect();

+        #[allow(clippy::needless_range_loop)]
         // Update distances from all other clusters `g` to the new cluster `i` where `g < i`.
         for g in 0..i {
             let offset = i - g - 1;
             if let Some(other_cluster_indices) = indices_mapping.get(&g) {
                 matrix[g][offset] = Self::compute_distance(
-                    &data,
+                    data,
                     &parameters.linkage,
                     &mut cache,
                     &combined_cluster_indices, // Order does not matter for Ward's method.
-                    &other_cluster_indices,
+                    other_cluster_indices,
                 )
             }
         }
@@ -373,7 +367,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
 }

 impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
-    UnsupervisedEstimator<X, AgglomerativeClusteringParameters> for AgglomerativeClustering<TX, TY, X, Y>
+    UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
+    for AgglomerativeClustering<TX, TY, X, Y>
 {
     fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
         AgglomerativeClustering::fit(x, parameters)
@@ -386,7 +381,7 @@ mod tests {
     use crate::linalg::basic::matrix::DenseMatrix;
     use std::collections::HashSet;

-    fn assert_approx_eq(a: f32, b: f32) {
+    fn assert_approx_eq(a: f64, b: f64) {
         assert!(
             (a - b).abs() < 1e-6,
             "assertion failed: `(left != right)` \n left: `{:?}`\n right: `{:?}`",
             a,
             b
         );
@@ -401,7 +396,7 @@ mod tests {

         // Variance of a single point is 0
         let variance1 =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::compute_cluster_variance(
                 &data,
                 &vec![0],
                 &vec![],
@@ -412,7 +407,7 @@ mod tests {
         // Mean is [2,2]
         // Variance = ((1-2)^2 + (1-2)^2) + ((3-2)^2 + (3-2)^2) = (1+1) + (1+1) = 4.0
         let variance2 =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::compute_cluster_variance(
                 &data,
                 &vec![0],
                 &vec![1],
@@ -424,7 +419,7 @@ mod tests {
         // Variance = ((1-3)^2+(1-3)^2) + ((3-3)^2+(3-3)^2) + ((5-3)^2+(5-3)^2)
         //          = (4+4) + (0+0) + (4+4) = 16.0
         let variance3 =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_cluster_variance(
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::compute_cluster_variance(
                 &data,
                 &vec![0, 1, 2],
                 &vec![],
@@ -444,7 +439,7 @@ mod tests {
         // var(c1 U c2) = 4.0 (from test above)
         // distance = 4.0 - 0 - 0 = 4.0
         let distance =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::compute_distance(
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::compute_distance(
                 &data,
                 &Linkage::Ward,
                 &mut cache,
@@ -476,7 +471,7 @@ mod tests {
         };

         let result =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params)
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&data, params)
                 .unwrap();
         let labels = result.labels;

@@ -511,7 +506,7 @@ mod tests {
             linkage: Linkage::Ward,
         };
         let result_3 =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params_3)
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&data, params_3)
                 .unwrap();
         let unique_labels_3: HashSet<usize> = result_3.labels.into_iter().collect();
         assert_eq!(unique_labels_3.len(), 3);
@@ -522,79 +517,79 @@ mod tests {
             linkage: Linkage::Ward,
         };
         let result_1 =
-            AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params_1)
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&data, params_1)
                 .unwrap();
         let unique_labels_1: HashSet<usize> = result_1.labels.into_iter().collect();
         assert_eq!(unique_labels_1.len(), 1);
     }

-    #[test]
-fn test_fit_heavy_load_deterministic() {
-    let n_clusters = 5;
-
-    // Define cluster properties: (center_x, center_y, num_points)
-    let cluster_definitions = vec![
-        (0.0, 0.0, 10),
-        (100.0, 0.0, 20),
-        (0.0, 100.0, 15),
-        (100.0, 100.0, 25),
-        (50.0, -50.0, 5),
-    ];
-
-    // The expected sizes of the final clusters.
-    let mut expected_counts: Vec<usize> =
-        cluster_definitions.iter().map(|c| c.2).collect();
-    expected_counts.sort_unstable();
-
-    let mut data_vec: Vec<Vec<f32>> = Vec::new();
-
-    // Generate data points for each cluster deterministically.
-    for (center_x, center_y, num_points) in cluster_definitions {
-        for i in 0..num_points {
-            // Add a small, predictable offset to each point based on its index.
-            // This creates a small, non-random spread around the center.
-            let offset = i as f32 * 0.1;
-            let x = center_x + offset;
-            let y = center_y + offset;
-            data_vec.push(vec![x, y]);
+    #[test]
+    fn test_fit_heavy_load_deterministic() {
+        let n_clusters = 5;
+
+        // Define cluster properties: (center_x, center_y, num_points)
+        let cluster_definitions = vec![
+            (0.0, 0.0, 10),
+            (100.0, 0.0, 20),
+            (0.0, 100.0, 15),
+            (100.0, 100.0, 25),
+            (50.0, -50.0, 5),
+        ];
+
+        // The expected sizes of the final clusters.
+        let mut expected_counts: Vec<usize> = cluster_definitions.iter().map(|c| c.2).collect();
+        expected_counts.sort_unstable();
+
+        let mut data_vec: Vec<Vec<f64>> = Vec::new();
+
+        // Generate data points for each cluster deterministically.
+        for (center_x, center_y, num_points) in cluster_definitions {
+            for i in 0..num_points {
+                // Add a small, predictable offset to each point based on its index.
+                // This creates a small, non-random spread around the center.
+                let offset = i as f64 * 0.1;
+                let x = center_x + offset;
+                let y = center_y + offset;
+                data_vec.push(vec![x, y]);
+            }
         }
-    }
-
-    // Convert to DenseMatrix
-    let data_refs: Vec<&[f32]> = data_vec.iter().map(|row| row.as_slice()).collect();
-    let data = DenseMatrix::from_2d_array(&data_refs).unwrap();
-
-    // Run clustering
-    let params = AgglomerativeClusteringParameters {
-        n_clusters,
-        linkage: Linkage::Ward,
-    };
-    let result = AgglomerativeClustering::<f32, usize, DenseMatrix<f32>, Vec<usize>>::fit(&data, params).unwrap();
-    let labels = result.labels;
-
-    // 1. Verify the number of distinct clusters found
-    let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
-    assert_eq!(
-        unique_labels.len(),
-        n_clusters,
-        "Expected {} distinct clusters, but found {}",
-        n_clusters,
-        unique_labels.len()
-    );
-
-    // 2. Verify the number of members in each cluster
-    let mut label_counts: HashMap<usize, usize> = HashMap::new();
-    for label in labels {
-        *label_counts.entry(label).or_insert(0) += 1;
-    }
+        // Convert to DenseMatrix
+        let data_refs: Vec<&[f64]> = data_vec.iter().map(|row| row.as_slice()).collect();
+        let data = DenseMatrix::from_2d_array(&data_refs).unwrap();
+
+        // Run clustering
+        let params = AgglomerativeClusteringParameters {
+            n_clusters,
+            linkage: Linkage::Ward,
+        };
+        let result =
+            AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&data, params)
+                .unwrap();
+        let labels = result.labels;

-    let mut actual_counts: Vec<usize> = label_counts.values().cloned().collect();
-    actual_counts.sort_unstable();
+        // 1. Verify the number of distinct clusters found
+        let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
+        assert_eq!(
+            unique_labels.len(),
+            n_clusters,
+            "Expected {} distinct clusters, but found {}",
+            n_clusters,
+            unique_labels.len()
+        );

-    assert_eq!(
-        actual_counts, expected_counts,
-        "Cluster sizes do not match expected values"
-    );
-}
-
+        // 2. Verify the number of members in each cluster
+        let mut label_counts: HashMap<usize, usize> = HashMap::new();
+        for label in labels {
+            *label_counts.entry(label).or_insert(0) += 1;
+        }
+
+        let mut actual_counts: Vec<usize> = label_counts.values().cloned().collect();
+        actual_counts.sort_unstable();
+
+        assert_eq!(
+            actual_counts, expected_counts,
+            "Cluster sizes do not match expected values"
+        );
+    }
 }
diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs
index 76da5a39..2fade68f 100644
--- a/src/cluster/kmeans.rs
+++ b/src/cluster/kmeans.rs
@@ -413,8 +413,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
     }
 }

-
-
 #[cfg(test)]
 mod tests {
     use super::*;
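After PATCH 8/8 the internals are `f64`-based and the model is also reachable through the `UnsupervisedEstimator` trait it now implements. A hedged sketch of trait-driven usage follows — module paths are assumed as in the earlier example, and the trait method simply delegates to the inherent `AgglomerativeClustering::fit`:

```rust
use smartcore::api::UnsupervisedEstimator;
use smartcore::cluster::hierarchal::{
    AgglomerativeClustering, AgglomerativeClusteringParameters, Linkage,
};
use smartcore::linalg::basic::matrix::DenseMatrix;

fn main() {
    // Two well-separated 1-D groups.
    let x = DenseMatrix::from_2d_array(&[&[0.0], &[1.0], &[10.0], &[11.0]]).unwrap();

    let params = AgglomerativeClusteringParameters {
        n_clusters: 2,
        linkage: Linkage::Ward,
    };

    // Annotating the result type lets the compiler resolve the trait impl;
    // note the f64 type parameters introduced by this patch.
    let model: AgglomerativeClustering<f64, usize, DenseMatrix<f64>, Vec<usize>> =
        UnsupervisedEstimator::fit(&x, params).unwrap();

    assert_eq!(model.labels.len(), 4);
}
```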