Skip to content

Commit d42fc3e

Browse files
committed
Add CholeskyImpl, InvCholeskyImpl, SolveCholeskyImpl
1 parent 608010c commit d42fc3e

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

lax/src/cholesky.rs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,116 @@ use super::*;
22
use crate::{error::*, layout::*};
33
use cauchy::*;
44

5+
pub trait CholeskyImpl: Scalar {
6+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
7+
}
8+
9+
macro_rules! impl_cholesky_ {
10+
($s:ty, $trf:path) => {
11+
impl CholeskyImpl for $s {
12+
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
13+
let (n, _) = l.size();
14+
if matches!(l, MatrixLayout::C { .. }) {
15+
square_transpose(l, a);
16+
}
17+
let mut info = 0;
18+
unsafe {
19+
$trf(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &n, &mut info);
20+
}
21+
info.as_lapack_result()?;
22+
if matches!(l, MatrixLayout::C { .. }) {
23+
square_transpose(l, a);
24+
}
25+
Ok(())
26+
}
27+
}
28+
};
29+
}
30+
impl_cholesky_!(c64, lapack_sys::zpotrf_);
31+
impl_cholesky_!(c32, lapack_sys::cpotrf_);
32+
impl_cholesky_!(f64, lapack_sys::dpotrf_);
33+
impl_cholesky_!(f32, lapack_sys::spotrf_);
34+
35+
pub trait InvCholeskyImpl: Scalar {
36+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
37+
}
38+
39+
macro_rules! impl_inv_cholesky {
40+
($s:ty, $tri:path) => {
41+
impl InvCholeskyImpl for $s {
42+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
43+
let (n, _) = l.size();
44+
if matches!(l, MatrixLayout::C { .. }) {
45+
square_transpose(l, a);
46+
}
47+
let mut info = 0;
48+
unsafe {
49+
$tri(uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), &mut info);
50+
}
51+
info.as_lapack_result()?;
52+
if matches!(l, MatrixLayout::C { .. }) {
53+
square_transpose(l, a);
54+
}
55+
Ok(())
56+
}
57+
}
58+
};
59+
}
60+
impl_inv_cholesky!(c64, lapack_sys::zpotri_);
61+
impl_inv_cholesky!(c32, lapack_sys::cpotri_);
62+
impl_inv_cholesky!(f64, lapack_sys::dpotri_);
63+
impl_inv_cholesky!(f32, lapack_sys::spotri_);
64+
65+
pub trait SolveCholeskyImpl: Scalar {
66+
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
67+
}
68+
69+
macro_rules! impl_solve_cholesky {
70+
($s:ty, $trs:path) => {
71+
impl SolveCholeskyImpl for $s {
72+
fn solve_cholesky(
73+
l: MatrixLayout,
74+
mut uplo: UPLO,
75+
a: &[Self],
76+
b: &mut [Self],
77+
) -> Result<()> {
78+
let (n, _) = l.size();
79+
let nrhs = 1;
80+
let mut info = 0;
81+
if matches!(l, MatrixLayout::C { .. }) {
82+
uplo = uplo.t();
83+
for val in b.iter_mut() {
84+
*val = val.conj();
85+
}
86+
}
87+
unsafe {
88+
$trs(
89+
uplo.as_ptr(),
90+
&n,
91+
&nrhs,
92+
AsPtr::as_ptr(a),
93+
&l.lda(),
94+
AsPtr::as_mut_ptr(b),
95+
&n,
96+
&mut info,
97+
);
98+
}
99+
info.as_lapack_result()?;
100+
if matches!(l, MatrixLayout::C { .. }) {
101+
for val in b.iter_mut() {
102+
*val = val.conj();
103+
}
104+
}
105+
Ok(())
106+
}
107+
}
108+
};
109+
}
110+
impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
111+
impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
112+
impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
113+
impl_solve_cholesky!(f32, lapack_sys::spotrs_);
114+
5115
#[cfg_attr(doc, katexit::katexit)]
6116
/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition
7117
///

0 commit comments

Comments
 (0)