Skip to content

Commit f0b348d

Browse files
authored
feat: BernoulliNB (#31)
* feat: BernoulliNB * Move preprocessing to a trait in linalg/stats.rs
1 parent 4720a3a commit f0b348d

File tree

7 files changed

+367
-4
lines changed

7 files changed

+367
-4
lines changed

src/linalg/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use evd::EVDDecomposableMatrix;
6363
use high_order::HighOrderOperations;
6464
use lu::LUDecomposableMatrix;
6565
use qr::QRDecomposableMatrix;
66-
use stats::MatrixStats;
66+
use stats::{MatrixPreprocessing, MatrixStats};
6767
use svd::SVDDecomposableMatrix;
6868

6969
/// Column or row vector
@@ -619,6 +619,7 @@ pub trait Matrix<T: RealNumber>:
619619
+ LUDecomposableMatrix<T>
620620
+ CholeskyDecomposableMatrix<T>
621621
+ MatrixStats<T>
622+
+ MatrixPreprocessing<T>
622623
+ HighOrderOperations<T>
623624
+ PartialEq
624625
+ Display

src/linalg/naive/dense_matrix.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
1212
use crate::linalg::high_order::HighOrderOperations;
1313
use crate::linalg::lu::LUDecomposableMatrix;
1414
use crate::linalg::qr::QRDecomposableMatrix;
15-
use crate::linalg::stats::MatrixStats;
15+
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
1616
use crate::linalg::svd::SVDDecomposableMatrix;
1717
use crate::linalg::Matrix;
1818
pub use crate::linalg::{BaseMatrix, BaseVector};
@@ -478,6 +478,7 @@ impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
478478
}
479479

480480
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
481+
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
481482

482483
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
483484

src/linalg/nalgebra_bindings.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
4747
use crate::linalg::high_order::HighOrderOperations;
4848
use crate::linalg::lu::LUDecomposableMatrix;
4949
use crate::linalg::qr::QRDecomposableMatrix;
50-
use crate::linalg::stats::MatrixStats;
50+
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
5151
use crate::linalg::svd::SVDDecomposableMatrix;
5252
use crate::linalg::Matrix as SmartCoreMatrix;
5353
use crate::linalg::{BaseMatrix, BaseVector};
@@ -554,6 +554,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
554554
{
555555
}
556556

557+
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
558+
MatrixPreprocessing<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
559+
{
560+
}
561+
557562
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
558563
HighOrderOperations<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
559564
{

src/linalg/ndarray_bindings.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
5454
use crate::linalg::high_order::HighOrderOperations;
5555
use crate::linalg::lu::LUDecomposableMatrix;
5656
use crate::linalg::qr::QRDecomposableMatrix;
57-
use crate::linalg::stats::MatrixStats;
57+
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
5858
use crate::linalg::svd::SVDDecomposableMatrix;
5959
use crate::linalg::Matrix;
6060
use crate::linalg::{BaseMatrix, BaseVector};
@@ -503,6 +503,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
503503
{
504504
}
505505

506+
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
507+
MatrixPreprocessing<T> for ArrayBase<OwnedRepr<T>, Ix2>
508+
{
509+
}
510+
506511
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
507512
HighOrderOperations<T> for ArrayBase<OwnedRepr<T>, Ix2>
508513
{

src/linalg/stats.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,47 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
104104
}
105105
}
106106

107+
/// Defines baseline implementations for various matrix processing functions
108+
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
109+
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
110+
/// ```
111+
/// use smartcore::linalg::naive::dense_matrix::*;
112+
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
113+
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
114+
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
115+
/// a.binarize_mut(0.);
116+
///
117+
/// assert_eq!(a, expected);
118+
/// ```
119+
120+
fn binarize_mut(&mut self, threshold: T) {
121+
let (nrows, ncols) = self.shape();
122+
for row in 0..nrows {
123+
for col in 0..ncols {
124+
if self.get(row, col) > threshold {
125+
self.set(row, col, T::one());
126+
} else {
127+
self.set(row, col, T::zero());
128+
}
129+
}
130+
}
131+
}
132+
/// Returns new matrix where elements are binarized according to a given threshold.
133+
/// ```
134+
/// use smartcore::linalg::naive::dense_matrix::*;
135+
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
136+
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
137+
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
138+
///
139+
/// assert_eq!(a.binarize(0.), expected);
140+
/// ```
141+
fn binarize(&self, threshold: T) -> Self {
142+
let mut m = self.clone();
143+
m.binarize_mut(threshold);
144+
m
145+
}
146+
}
147+
107148
#[cfg(test)]
108149
mod tests {
109150
use super::*;

0 commit comments

Comments
 (0)