Skip to content

Commit 4efad85

Browse files
Merge pull request #21 from smartcorelib/cholesky
feat: adds Cholesky matrix decomposition
2 parents 7007e06 + b8fea67 commit 4efad85

File tree

6 files changed

+227
-0
lines changed

6 files changed

+227
-0
lines changed

src/error/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub enum FailedError {
2424
FindFailed,
2525
/// Can't decompose a matrix
2626
DecompositionFailed,
27+
/// Can't solve for x
28+
SolutionFailed,
2729
}
2830

2931
impl Failed {
@@ -87,6 +89,7 @@ impl fmt::Display for FailedError {
8789
FailedError::TransformFailed => "Transform failed",
8890
FailedError::FindFailed => "Find failed",
8991
FailedError::DecompositionFailed => "Decomposition failed",
92+
FailedError::SolutionFailed => "Can't find solution",
9093
};
9194
write!(f, "{}", failed_err_str)
9295
}

src/linalg/cholesky.rs

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
//! # Cholesky Decomposition
2+
//!
3+
//! every positive definite matrix \\(A \in R^{n \times n}\\) can be factored as
4+
//!
5+
//! \\[A = R^TR\\]
6+
//!
7+
//! where \\(R\\) is upper triangular matrix with positive diagonal elements
8+
//!
9+
//! Example:
10+
//! ```
11+
//! use smartcore::linalg::naive::dense_matrix::*;
12+
//! use crate::smartcore::linalg::cholesky::*;
13+
//!
14+
//! let A = DenseMatrix::from_2d_array(&[
15+
//! &[25., 15., -5.],
16+
//! &[15., 18., 0.],
17+
//! &[-5., 0., 11.]
18+
//! ]);
19+
//!
20+
//! let cholesky = A.cholesky().unwrap();
21+
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
22+
//! let upper_triangular: DenseMatrix<f64> = cholesky.U();
23+
//! ```
24+
//!
25+
//! ## References:
26+
//! * ["No bullshit guide to linear algebra", Ivan Savov, 2016, 7.6 Matrix decompositions](https://minireference.com/)
27+
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., 2.9 Cholesky Decomposition](http://numerical.recipes/)
28+
//!
29+
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
30+
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
31+
#![allow(non_snake_case)]
32+
33+
use std::fmt::Debug;
34+
use std::marker::PhantomData;
35+
36+
use crate::error::{Failed, FailedError};
37+
use crate::linalg::BaseMatrix;
38+
use crate::math::num::RealNumber;
39+
40+
#[derive(Debug, Clone)]
41+
/// Results of Cholesky decomposition.
42+
pub struct Cholesky<T: RealNumber, M: BaseMatrix<T>> {
43+
R: M,
44+
t: PhantomData<T>,
45+
}
46+
47+
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
48+
pub(crate) fn new(R: M) -> Cholesky<T, M> {
49+
Cholesky {
50+
R: R,
51+
t: PhantomData,
52+
}
53+
}
54+
55+
/// Get lower triangular matrix.
56+
pub fn L(&self) -> M {
57+
let (n, _) = self.R.shape();
58+
let mut R = M::zeros(n, n);
59+
60+
for i in 0..n {
61+
for j in 0..n {
62+
if j <= i {
63+
R.set(i, j, self.R.get(i, j));
64+
}
65+
}
66+
}
67+
R
68+
}
69+
70+
/// Get upper triangular matrix.
71+
pub fn U(&self) -> M {
72+
let (n, _) = self.R.shape();
73+
let mut R = M::zeros(n, n);
74+
75+
for i in 0..n {
76+
for j in 0..n {
77+
if j <= i {
78+
R.set(j, i, self.R.get(i, j));
79+
}
80+
}
81+
}
82+
R
83+
}
84+
85+
/// Solves Ax = b
86+
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
87+
let (bn, m) = b.shape();
88+
let (rn, _) = self.R.shape();
89+
90+
if bn != rn {
91+
return Err(Failed::because(
92+
FailedError::SolutionFailed,
93+
&format!("Can't solve Ax = b for x. Number of rows in b != number of rows in R."),
94+
));
95+
}
96+
97+
for k in 0..bn {
98+
for j in 0..m {
99+
for i in 0..k {
100+
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i));
101+
}
102+
b.div_element_mut(k, j, self.R.get(k, k));
103+
}
104+
}
105+
106+
for k in (0..bn).rev() {
107+
for j in 0..m {
108+
for i in k + 1..bn {
109+
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k));
110+
}
111+
b.div_element_mut(k, j, self.R.get(k, k));
112+
}
113+
}
114+
Ok(b)
115+
}
116+
}
117+
118+
/// Trait that implements Cholesky decomposition routine for any matrix.
119+
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
120+
/// Compute the Cholesky decomposition of a matrix.
121+
fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
122+
self.clone().cholesky_mut()
123+
}
124+
125+
/// Compute the Cholesky decomposition of a matrix. The input matrix
126+
/// will be used for factorization.
127+
fn cholesky_mut(mut self) -> Result<Cholesky<T, Self>, Failed> {
128+
let (m, n) = self.shape();
129+
130+
if m != n {
131+
return Err(Failed::because(
132+
FailedError::DecompositionFailed,
133+
&format!("Can't do Cholesky decomposition on a non-square matrix"),
134+
));
135+
}
136+
137+
for j in 0..n {
138+
let mut d = T::zero();
139+
for k in 0..j {
140+
let mut s = T::zero();
141+
for i in 0..k {
142+
s += self.get(k, i) * self.get(j, i);
143+
}
144+
s = (self.get(j, k) - s) / self.get(k, k);
145+
self.set(j, k, s);
146+
d = d + s * s;
147+
}
148+
d = self.get(j, j) - d;
149+
150+
if d < T::zero() {
151+
return Err(Failed::because(
152+
FailedError::DecompositionFailed,
153+
&format!("The matrix is not positive definite."),
154+
));
155+
}
156+
157+
self.set(j, j, d.sqrt());
158+
}
159+
160+
Ok(Cholesky::new(self))
161+
}
162+
163+
/// Solves Ax = b
164+
fn cholesky_solve_mut(self, b: Self) -> Result<Self, Failed> {
165+
self.cholesky_mut().and_then(|qr| qr.solve(b))
166+
}
167+
}
168+
169+
#[cfg(test)]
170+
mod tests {
171+
use super::*;
172+
use crate::linalg::naive::dense_matrix::*;
173+
174+
#[test]
175+
fn cholesky_decompose() {
176+
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
177+
let l =
178+
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
179+
let u =
180+
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
181+
let cholesky = a.cholesky().unwrap();
182+
183+
assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
184+
assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
185+
assert!(cholesky
186+
.L()
187+
.matmul(&cholesky.U())
188+
.abs()
189+
.approximate_eq(&a.abs(), 1e-4));
190+
}
191+
192+
#[test]
193+
fn cholesky_solve_mut() {
194+
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
195+
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
196+
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
197+
198+
let cholesky = a.cholesky().unwrap();
199+
200+
assert!(cholesky
201+
.solve(b.transpose())
202+
.unwrap()
203+
.transpose()
204+
.approximate_eq(&expected, 1e-4));
205+
}
206+
}

src/linalg/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
//! let u: DenseMatrix<f64> = svd.U;
3434
//! ```
3535
36+
pub mod cholesky;
3637
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
3738
pub mod evd;
3839
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
@@ -55,6 +56,7 @@ use std::marker::PhantomData;
5556
use std::ops::Range;
5657

5758
use crate::math::num::RealNumber;
59+
use cholesky::CholeskyDecomposableMatrix;
5860
use evd::EVDDecomposableMatrix;
5961
use lu::LUDecomposableMatrix;
6062
use qr::QRDecomposableMatrix;
@@ -507,6 +509,7 @@ pub trait Matrix<T: RealNumber>:
507509
+ EVDDecomposableMatrix<T>
508510
+ QRDecomposableMatrix<T>
509511
+ LUDecomposableMatrix<T>
512+
+ CholeskyDecomposableMatrix<T>
510513
+ PartialEq
511514
+ Display
512515
{

src/linalg/naive/dense_matrix.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
88
use serde::ser::{SerializeStruct, Serializer};
99
use serde::{Deserialize, Serialize};
1010

11+
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
1112
use crate::linalg::evd::EVDDecomposableMatrix;
1213
use crate::linalg::lu::LUDecomposableMatrix;
1314
use crate::linalg::qr::QRDecomposableMatrix;
@@ -442,6 +443,8 @@ impl<T: RealNumber> QRDecomposableMatrix<T> for DenseMatrix<T> {}
442443

443444
impl<T: RealNumber> LUDecomposableMatrix<T> for DenseMatrix<T> {}
444445

446+
impl<T: RealNumber> CholeskyDecomposableMatrix<T> for DenseMatrix<T> {}
447+
445448
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
446449

447450
impl<T: RealNumber> PartialEq for DenseMatrix<T> {

src/linalg/nalgebra_bindings.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
4242

4343
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};
4444

45+
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
4546
use crate::linalg::evd::EVDDecomposableMatrix;
4647
use crate::linalg::lu::LUDecomposableMatrix;
4748
use crate::linalg::qr::QRDecomposableMatrix;
@@ -544,6 +545,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
544545
{
545546
}
546547

548+
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
549+
CholeskyDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
550+
{
551+
}
552+
547553
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
548554
SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
549555
{

src/linalg/ndarray_bindings.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use std::ops::SubAssign;
4949
use ndarray::ScalarOperand;
5050
use ndarray::{s, stack, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr};
5151

52+
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
5253
use crate::linalg::evd::EVDDecomposableMatrix;
5354
use crate::linalg::lu::LUDecomposableMatrix;
5455
use crate::linalg::qr::QRDecomposableMatrix;
@@ -494,6 +495,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
494495
{
495496
}
496497

498+
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
499+
CholeskyDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
500+
{
501+
}
502+
497503
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T>
498504
for ArrayBase<OwnedRepr<T>, Ix2>
499505
{

0 commit comments

Comments
 (0)