Skip to content

Commit 82464f4

Browse files
Merge pull request #23 from smartcorelib/ridge
Ridge regression
2 parents b86c553 + 830a0d9 commit 82464f4

File tree

9 files changed

+627
-31
lines changed

9 files changed

+627
-31
lines changed

src/linalg/mod.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pub mod nalgebra_bindings;
4848
pub mod ndarray_bindings;
4949
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
5050
pub mod qr;
51+
pub mod stats;
5152
/// Singular value decomposition.
5253
pub mod svd;
5354

@@ -60,6 +61,7 @@ use cholesky::CholeskyDecomposableMatrix;
6061
use evd::EVDDecomposableMatrix;
6162
use lu::LUDecomposableMatrix;
6263
use qr::QRDecomposableMatrix;
64+
use stats::MatrixStats;
6365
use svd::SVDDecomposableMatrix;
6466

6567
/// Column or row vector
@@ -168,6 +170,30 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
168170
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
169171
/// ```
170172
fn unique(&self) -> Vec<T>;
173+
174+
/// Computes the arithmetic mean.
175+
fn mean(&self) -> T {
176+
self.sum() / T::from_usize(self.len()).unwrap()
177+
}
178+
/// Computes variance.
179+
fn var(&self) -> T {
180+
let n = self.len();
181+
182+
let mut mu = T::zero();
183+
let mut sum = T::zero();
184+
let div = T::from_usize(n).unwrap();
185+
for i in 0..n {
186+
let xi = self.get(i);
187+
mu += xi;
188+
sum += xi * xi;
189+
}
190+
mu /= div;
191+
sum / div - mu * mu
192+
}
193+
/// Computes the standard deviation.
194+
fn std(&self) -> T {
195+
self.var().sqrt()
196+
}
171197
}
172198

173199
/// Generic matrix type.
@@ -515,6 +541,7 @@ pub trait Matrix<T: RealNumber>:
515541
+ QRDecomposableMatrix<T>
516542
+ LUDecomposableMatrix<T>
517543
+ CholeskyDecomposableMatrix<T>
544+
+ MatrixStats<T>
518545
+ PartialEq
519546
+ Display
520547
{
@@ -550,3 +577,29 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
550577
res
551578
}
552579
}
580+
581+
#[cfg(test)]
mod tests {
    use crate::linalg::BaseVector;

    // The mean of 1, 2, 3 is exactly 2.
    #[test]
    fn mean() {
        let values = vec![1., 2., 3.];
        let actual = values.mean();

        assert_eq!(actual, 2.0);
    }

    // The standard deviation of 1, 2, 3 is compared against 0.81 with a 1e-2 tolerance.
    #[test]
    fn std() {
        let values = vec![1., 2., 3.];
        let error = (values.std() - 0.81f64).abs();

        assert!(error < 1e-2);
    }

    // The variance of 1, 2, 3, 4 is expected to be 1.25 (divisor is the element count).
    #[test]
    fn var() {
        let values = vec![1., 2., 3., 4.];
        let error = (values.var() - 1.25f64).abs();

        assert!(error < std::f64::EPSILON);
    }
}

src/linalg/naive/dense_matrix.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::linalg::cholesky::CholeskyDecomposableMatrix;
1111
use crate::linalg::evd::EVDDecomposableMatrix;
1212
use crate::linalg::lu::LUDecomposableMatrix;
1313
use crate::linalg::qr::QRDecomposableMatrix;
14+
use crate::linalg::stats::MatrixStats;
1415
use crate::linalg::svd::SVDDecomposableMatrix;
1516
use crate::linalg::Matrix;
1617
pub use crate::linalg::{BaseMatrix, BaseVector};
@@ -443,6 +444,8 @@ impl<T: RealNumber> LUDecomposableMatrix<T> for DenseMatrix<T> {}
443444

444445
impl<T: RealNumber> CholeskyDecomposableMatrix<T> for DenseMatrix<T> {}
445446

447+
// Empty impl: `DenseMatrix` overrides nothing and relies on the default method
// bodies provided by the `MatrixStats` trait.
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
448+
446449
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
447450

448451
impl<T: RealNumber> PartialEq for DenseMatrix<T> {

src/linalg/nalgebra_bindings.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use crate::linalg::cholesky::CholeskyDecomposableMatrix;
4646
use crate::linalg::evd::EVDDecomposableMatrix;
4747
use crate::linalg::lu::LUDecomposableMatrix;
4848
use crate::linalg::qr::QRDecomposableMatrix;
49+
use crate::linalg::stats::MatrixStats;
4950
use crate::linalg::svd::SVDDecomposableMatrix;
5051
use crate::linalg::Matrix as SmartCoreMatrix;
5152
use crate::linalg::{BaseMatrix, BaseVector};
@@ -546,6 +547,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
546547
{
547548
}
548549

550+
// Empty impl: this matrix type overrides nothing and relies on the default method
// bodies provided by the `MatrixStats` trait.
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
551+
MatrixStats<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
552+
{
553+
}
554+
549555
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
550556
SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
551557
{

src/linalg/ndarray_bindings.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use crate::linalg::cholesky::CholeskyDecomposableMatrix;
5353
use crate::linalg::evd::EVDDecomposableMatrix;
5454
use crate::linalg::lu::LUDecomposableMatrix;
5555
use crate::linalg::qr::QRDecomposableMatrix;
56+
use crate::linalg::stats::MatrixStats;
5657
use crate::linalg::svd::SVDDecomposableMatrix;
5758
use crate::linalg::Matrix;
5859
use crate::linalg::{BaseMatrix, BaseVector};
@@ -496,6 +497,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
496497
{
497498
}
498499

500+
// Empty impl: this array type overrides nothing and relies on the default method
// bodies provided by the `MatrixStats` trait.
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
501+
MatrixStats<T> for ArrayBase<OwnedRepr<T>, Ix2>
502+
{
503+
}
504+
499505
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T>
500506
for ArrayBase<OwnedRepr<T>, Ix2>
501507
{

src/linalg/stats.rs

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
//! # Various Statistical Methods
2+
//!
3+
//! This module provides reference implementations for various statistical functions.
4+
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
5+
6+
use crate::linalg::BaseMatrix;
7+
use crate::math::num::RealNumber;
8+
9+
/// Defines baseline implementations for various statistical functions
10+
pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
11+
/// Computes the arithmetic mean along the specified axis.
12+
fn mean(&self, axis: u8) -> Vec<T> {
13+
let (n, m) = match axis {
14+
0 => {
15+
let (n, m) = self.shape();
16+
(m, n)
17+
}
18+
_ => self.shape(),
19+
};
20+
21+
let mut x: Vec<T> = vec![T::zero(); n];
22+
23+
let div = T::from_usize(m).unwrap();
24+
25+
for i in 0..n {
26+
for j in 0..m {
27+
x[i] += match axis {
28+
0 => self.get(j, i),
29+
_ => self.get(i, j),
30+
};
31+
}
32+
x[i] /= div;
33+
}
34+
35+
x
36+
}
37+
38+
/// Computes variance along the specified axis.
39+
fn var(&self, axis: u8) -> Vec<T> {
40+
let (n, m) = match axis {
41+
0 => {
42+
let (n, m) = self.shape();
43+
(m, n)
44+
}
45+
_ => self.shape(),
46+
};
47+
48+
let mut x: Vec<T> = vec![T::zero(); n];
49+
50+
let div = T::from_usize(m).unwrap();
51+
52+
for i in 0..n {
53+
let mut mu = T::zero();
54+
let mut sum = T::zero();
55+
for j in 0..m {
56+
let a = match axis {
57+
0 => self.get(j, i),
58+
_ => self.get(i, j),
59+
};
60+
mu += a;
61+
sum += a * a;
62+
}
63+
mu /= div;
64+
x[i] = sum / div - mu * mu;
65+
}
66+
67+
x
68+
}
69+
70+
/// Computes the standard deviation along the specified axis.
71+
fn std(&self, axis: u8) -> Vec<T> {
72+
let mut x = self.var(axis);
73+
74+
let n = match axis {
75+
0 => self.shape().1,
76+
_ => self.shape().0,
77+
};
78+
79+
for i in 0..n {
80+
x[i] = x[i].sqrt();
81+
}
82+
83+
x
84+
}
85+
86+
/// standardize values by removing the mean and scaling to unit variance
87+
fn scale_mut(&mut self, mean: &Vec<T>, std: &Vec<T>, axis: u8) {
88+
let (n, m) = match axis {
89+
0 => {
90+
let (n, m) = self.shape();
91+
(m, n)
92+
}
93+
_ => self.shape(),
94+
};
95+
96+
for i in 0..n {
97+
for j in 0..m {
98+
match axis {
99+
0 => self.set(j, i, (self.get(j, i) - mean[i]) / std[i]),
100+
_ => self.set(i, j, (self.get(i, j) - mean[i]) / std[i]),
101+
}
102+
}
103+
}
104+
}
105+
}
106+
107+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;
    use crate::linalg::BaseVector;

    // Axis 0 yields one mean per column, axis 1 one mean per row.
    #[test]
    fn mean() {
        let matrix = DenseMatrix::from_2d_array(&[
            &[1., 2., 3., 1., 2.],
            &[4., 5., 6., 3., 4.],
            &[7., 8., 9., 5., 6.],
        ]);
        let per_column = vec![4., 5., 6., 3., 4.];
        let per_row = vec![1.8, 4.4, 7.];

        assert_eq!(matrix.mean(0), per_column);
        assert_eq!(matrix.mean(1), per_row);
    }

    // Standard deviations are compared with a 1e-2 tolerance on both axes.
    #[test]
    fn std() {
        let matrix = DenseMatrix::from_2d_array(&[
            &[1., 2., 3., 1., 2.],
            &[4., 5., 6., 3., 4.],
            &[7., 8., 9., 5., 6.],
        ]);
        let per_column = vec![2.44, 2.44, 2.44, 1.63, 1.63];
        let per_row = vec![0.74, 1.01, 1.41];

        let columns = matrix.std(0);
        let rows = matrix.std(1);

        assert!(columns.approximate_eq(&per_column, 1e-2));
        assert!(rows.approximate_eq(&per_row, 1e-2));
    }

    // Variances of a 2x4 matrix, checked to within machine epsilon.
    #[test]
    fn var() {
        let matrix = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
        let per_column = vec![4., 4., 4., 4.];
        let per_row = vec![1.25, 1.25];

        let columns = matrix.var(0);
        let rows = matrix.var(1);

        assert!(columns.approximate_eq(&per_column, std::f64::EPSILON));
        assert!(rows.approximate_eq(&per_row, std::f64::EPSILON));
    }

    // Standardizes a copy column-wise first, then the matrix itself row-wise.
    #[test]
    fn scale() {
        let mut matrix = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
        let scaled_columns = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]);
        let scaled_rows = DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]);

        {
            let mut copy = matrix.clone();
            let (mu, sigma) = (copy.mean(0), copy.std(0));
            copy.scale_mut(&mu, &sigma, 0);
            assert!(copy.approximate_eq(&scaled_columns, std::f32::EPSILON));
        }

        let (mu, sigma) = (matrix.mean(1), matrix.std(1));
        matrix.scale_mut(&mu, &sigma, 1);
        assert!(matrix.approximate_eq(&scaled_rows, 1e-2));
    }
}

src/linear/linear_regression.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
154154
}
155155

156156
/// Get estimated regression coefficients
157-
pub fn coefficients(&self) -> M {
158-
self.coefficients.clone()
157+
pub fn coefficients(&self) -> &M {
158+
&self.coefficients
159159
}
160160

161161
/// Get estimate of intercept

0 commit comments

Comments
 (0)