|
| 1 | +//! # Dimensionality reduction using SVD |
| 2 | +//! |
| 3 | +//! Similar to [`PCA`](../pca/index.html), SVD is a technique that can be used to reduce the number of input variables _p_ to a smaller number _k_, while preserving |
| 4 | +//! the most important structure or relationships between the variables observed in the data. |
| 5 | +//! |
| 6 | +//! Contrary to PCA, SVD does not center the data before computing the singular value decomposition. |
| 7 | +//! |
| 8 | +//! Example: |
| 9 | +//! ``` |
| 10 | +//! use smartcore::linalg::naive::dense_matrix::*; |
| 11 | +//! use smartcore::decomposition::svd::*; |
| 12 | +//! |
| 13 | +//! // Iris data |
| 14 | +//! let iris = DenseMatrix::from_2d_array(&[ |
| 15 | +//! &[5.1, 3.5, 1.4, 0.2], |
| 16 | +//! &[4.9, 3.0, 1.4, 0.2], |
| 17 | +//! &[4.7, 3.2, 1.3, 0.2], |
| 18 | +//! &[4.6, 3.1, 1.5, 0.2], |
| 19 | +//! &[5.0, 3.6, 1.4, 0.2], |
| 20 | +//! &[5.4, 3.9, 1.7, 0.4], |
| 21 | +//! &[4.6, 3.4, 1.4, 0.3], |
| 22 | +//! &[5.0, 3.4, 1.5, 0.2], |
| 23 | +//! &[4.4, 2.9, 1.4, 0.2], |
| 24 | +//! &[4.9, 3.1, 1.5, 0.1], |
| 25 | +//! &[7.0, 3.2, 4.7, 1.4], |
| 26 | +//! &[6.4, 3.2, 4.5, 1.5], |
| 27 | +//! &[6.9, 3.1, 4.9, 1.5], |
| 28 | +//! &[5.5, 2.3, 4.0, 1.3], |
| 29 | +//! &[6.5, 2.8, 4.6, 1.5], |
| 30 | +//! &[5.7, 2.8, 4.5, 1.3], |
| 31 | +//! &[6.3, 3.3, 4.7, 1.6], |
| 32 | +//! &[4.9, 2.4, 3.3, 1.0], |
| 33 | +//! &[6.6, 2.9, 4.6, 1.3], |
| 34 | +//! &[5.2, 2.7, 3.9, 1.4], |
| 35 | +//! ]); |
| 36 | +//! |
| 37 | +//! let svd = SVD::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2 |
| 38 | +//! |
| 39 | +//! let iris_reduced = svd.transform(&iris).unwrap(); |
| 40 | +//! |
| 41 | +//! ``` |
| 42 | +//! |
| 43 | +//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> |
| 44 | +//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> |
| 45 | +use std::fmt::Debug; |
| 46 | +use std::marker::PhantomData; |
| 47 | + |
| 48 | +use serde::{Deserialize, Serialize}; |
| 49 | + |
| 50 | +use crate::error::Failed; |
| 51 | +use crate::linalg::Matrix; |
| 52 | +use crate::math::num::RealNumber; |
| 53 | + |
| 54 | +/// SVD |
| 55 | +#[derive(Serialize, Deserialize, Debug)] |
| 56 | +pub struct SVD<T: RealNumber, M: Matrix<T>> { |
| 57 | + components: M, |
| 58 | + phantom: PhantomData<T>, |
| 59 | +} |
| 60 | + |
| 61 | +impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> { |
| 62 | + fn eq(&self, other: &Self) -> bool { |
| 63 | + self.components |
| 64 | + .approximate_eq(&other.components, T::from_f64(1e-8).unwrap()) |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +#[derive(Debug, Clone)] |
| 69 | +/// SVD parameters |
| 70 | +pub struct SVDParameters {} |
| 71 | + |
| 72 | +impl Default for SVDParameters { |
| 73 | + fn default() -> Self { |
| 74 | + SVDParameters {} |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +impl<T: RealNumber, M: Matrix<T>> SVD<T, M> { |
| 79 | + /// Fits SVD to your data. |
| 80 | + /// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation. |
| 81 | + /// * `n_components` - number of components to keep. |
| 82 | + /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values. |
| 83 | + pub fn fit(x: &M, n_components: usize, _: SVDParameters) -> Result<SVD<T, M>, Failed> { |
| 84 | + let (_, p) = x.shape(); |
| 85 | + |
| 86 | + if n_components >= p { |
| 87 | + return Err(Failed::fit(&format!( |
| 88 | + "Number of components, n_components should be < number of attributes ({})", |
| 89 | + p |
| 90 | + ))); |
| 91 | + } |
| 92 | + |
| 93 | + let svd = x.svd()?; |
| 94 | + |
| 95 | + let components = svd.V.slice(0..p, 0..n_components); |
| 96 | + |
| 97 | + Ok(SVD { |
| 98 | + components, |
| 99 | + phantom: PhantomData, |
| 100 | + }) |
| 101 | + } |
| 102 | + |
| 103 | + /// Run dimensionality reduction for `x` |
| 104 | + /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. |
| 105 | + pub fn transform(&self, x: &M) -> Result<M, Failed> { |
| 106 | + let (n, p) = x.shape(); |
| 107 | + let (p_c, k) = self.components.shape(); |
| 108 | + if p_c != p { |
| 109 | + return Err(Failed::transform(&format!( |
| 110 | + "Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions", |
| 111 | + n, p, n, k |
| 112 | + ))); |
| 113 | + } |
| 114 | + |
| 115 | + Ok(x.matmul(&self.components)) |
| 116 | + } |
| 117 | + |
| 118 | + /// Get a projection matrix |
| 119 | + pub fn components(&self) -> &M { |
| 120 | + &self.components |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +#[cfg(test)] |
| 125 | +mod tests { |
| 126 | + use super::*; |
| 127 | + use crate::linalg::naive::dense_matrix::*; |
| 128 | + |
| 129 | + #[test] |
| 130 | + fn svd_decompose() { |
| 131 | + // https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html |
| 132 | + let x = DenseMatrix::from_2d_array(&[ |
| 133 | + &[13.2, 236.0, 58.0, 21.2], |
| 134 | + &[10.0, 263.0, 48.0, 44.5], |
| 135 | + &[8.1, 294.0, 80.0, 31.0], |
| 136 | + &[8.8, 190.0, 50.0, 19.5], |
| 137 | + &[9.0, 276.0, 91.0, 40.6], |
| 138 | + &[7.9, 204.0, 78.0, 38.7], |
| 139 | + &[3.3, 110.0, 77.0, 11.1], |
| 140 | + &[5.9, 238.0, 72.0, 15.8], |
| 141 | + &[15.4, 335.0, 80.0, 31.9], |
| 142 | + &[17.4, 211.0, 60.0, 25.8], |
| 143 | + &[5.3, 46.0, 83.0, 20.2], |
| 144 | + &[2.6, 120.0, 54.0, 14.2], |
| 145 | + &[10.4, 249.0, 83.0, 24.0], |
| 146 | + &[7.2, 113.0, 65.0, 21.0], |
| 147 | + &[2.2, 56.0, 57.0, 11.3], |
| 148 | + &[6.0, 115.0, 66.0, 18.0], |
| 149 | + &[9.7, 109.0, 52.0, 16.3], |
| 150 | + &[15.4, 249.0, 66.0, 22.2], |
| 151 | + &[2.1, 83.0, 51.0, 7.8], |
| 152 | + &[11.3, 300.0, 67.0, 27.8], |
| 153 | + &[4.4, 149.0, 85.0, 16.3], |
| 154 | + &[12.1, 255.0, 74.0, 35.1], |
| 155 | + &[2.7, 72.0, 66.0, 14.9], |
| 156 | + &[16.1, 259.0, 44.0, 17.1], |
| 157 | + &[9.0, 178.0, 70.0, 28.2], |
| 158 | + &[6.0, 109.0, 53.0, 16.4], |
| 159 | + &[4.3, 102.0, 62.0, 16.5], |
| 160 | + &[12.2, 252.0, 81.0, 46.0], |
| 161 | + &[2.1, 57.0, 56.0, 9.5], |
| 162 | + &[7.4, 159.0, 89.0, 18.8], |
| 163 | + &[11.4, 285.0, 70.0, 32.1], |
| 164 | + &[11.1, 254.0, 86.0, 26.1], |
| 165 | + &[13.0, 337.0, 45.0, 16.1], |
| 166 | + &[0.8, 45.0, 44.0, 7.3], |
| 167 | + &[7.3, 120.0, 75.0, 21.4], |
| 168 | + &[6.6, 151.0, 68.0, 20.0], |
| 169 | + &[4.9, 159.0, 67.0, 29.3], |
| 170 | + &[6.3, 106.0, 72.0, 14.9], |
| 171 | + &[3.4, 174.0, 87.0, 8.3], |
| 172 | + &[14.4, 279.0, 48.0, 22.5], |
| 173 | + &[3.8, 86.0, 45.0, 12.8], |
| 174 | + &[13.2, 188.0, 59.0, 26.9], |
| 175 | + &[12.7, 201.0, 80.0, 25.5], |
| 176 | + &[3.2, 120.0, 80.0, 22.9], |
| 177 | + &[2.2, 48.0, 32.0, 11.2], |
| 178 | + &[8.5, 156.0, 63.0, 20.7], |
| 179 | + &[4.0, 145.0, 73.0, 26.2], |
| 180 | + &[5.7, 81.0, 39.0, 9.3], |
| 181 | + &[2.6, 53.0, 66.0, 10.8], |
| 182 | + &[6.8, 161.0, 60.0, 15.6], |
| 183 | + ]); |
| 184 | + |
| 185 | + let expected = DenseMatrix::from_2d_array(&[ |
| 186 | + &[243.54655757, -18.76673788], |
| 187 | + &[268.36802004, -33.79304302], |
| 188 | + &[305.93972467, -15.39087376], |
| 189 | + &[197.28420365, -11.66808306], |
| 190 | + &[293.43187394, 1.91163633], |
| 191 | + ]); |
| 192 | + let svd = SVD::fit(&x, 2, Default::default()).unwrap(); |
| 193 | + |
| 194 | + let x_transformed = svd.transform(&x).unwrap(); |
| 195 | + |
| 196 | + assert_eq!(svd.components.shape(), (x.shape().1, 2)); |
| 197 | + |
| 198 | + assert!(x_transformed |
| 199 | + .slice(0..5, 0..2) |
| 200 | + .approximate_eq(&expected, 1e-4)); |
| 201 | + } |
| 202 | + |
| 203 | + #[test] |
| 204 | + fn serde() { |
| 205 | + let iris = DenseMatrix::from_2d_array(&[ |
| 206 | + &[5.1, 3.5, 1.4, 0.2], |
| 207 | + &[4.9, 3.0, 1.4, 0.2], |
| 208 | + &[4.7, 3.2, 1.3, 0.2], |
| 209 | + &[4.6, 3.1, 1.5, 0.2], |
| 210 | + &[5.0, 3.6, 1.4, 0.2], |
| 211 | + &[5.4, 3.9, 1.7, 0.4], |
| 212 | + &[4.6, 3.4, 1.4, 0.3], |
| 213 | + &[5.0, 3.4, 1.5, 0.2], |
| 214 | + &[4.4, 2.9, 1.4, 0.2], |
| 215 | + &[4.9, 3.1, 1.5, 0.1], |
| 216 | + &[7.0, 3.2, 4.7, 1.4], |
| 217 | + &[6.4, 3.2, 4.5, 1.5], |
| 218 | + &[6.9, 3.1, 4.9, 1.5], |
| 219 | + &[5.5, 2.3, 4.0, 1.3], |
| 220 | + &[6.5, 2.8, 4.6, 1.5], |
| 221 | + &[5.7, 2.8, 4.5, 1.3], |
| 222 | + &[6.3, 3.3, 4.7, 1.6], |
| 223 | + &[4.9, 2.4, 3.3, 1.0], |
| 224 | + &[6.6, 2.9, 4.6, 1.3], |
| 225 | + &[5.2, 2.7, 3.9, 1.4], |
| 226 | + ]); |
| 227 | + |
| 228 | + let svd = SVD::fit(&iris, 2, Default::default()).unwrap(); |
| 229 | + |
| 230 | + let deserialized_svd: SVD<f64, DenseMatrix<f64>> = |
| 231 | + serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap(); |
| 232 | + |
| 233 | + assert_eq!(svd, deserialized_svd); |
| 234 | + } |
| 235 | +} |
0 commit comments