Skip to content

Commit a2be9e1

Browse files
Volodymyr Orlov
authored and committed
feat: + cross_validate, trait Predictor, refactoring
1 parent 40dfca7 commit a2be9e1

34 files changed

+976
-368
lines changed

src/algorithm/neighbour/cover_tree.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! use smartcore::algorithm::neighbour::cover_tree::*;
77
//! use smartcore::math::distance::Distance;
88
//!
9+
//! #[derive(Clone)]
910
//! struct SimpleDistance {} // Our distance function
1011
//!
1112
//! impl Distance<i32, f64> for SimpleDistance {
@@ -453,7 +454,7 @@ mod tests {
453454
use super::*;
454455
use crate::math::distance::Distances;
455456

456-
#[derive(Debug, Serialize, Deserialize)]
457+
#[derive(Debug, Serialize, Deserialize, Clone)]
457458
struct SimpleDistance {}
458459

459460
impl Distance<i32, f64> for SimpleDistance {

src/algorithm/neighbour/linear_search.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//! use smartcore::algorithm::neighbour::linear_search::*;
66
//! use smartcore::math::distance::Distance;
77
//!
8+
//! #[derive(Clone)]
89
//! struct SimpleDistance {} // Our distance function
910
//!
1011
//! impl Distance<i32, f64> for SimpleDistance {
@@ -137,6 +138,7 @@ mod tests {
137138
use super::*;
138139
use crate::math::distance::Distances;
139140

141+
#[derive(Debug, Serialize, Deserialize, Clone)]
140142
struct SimpleDistance {}
141143

142144
impl Distance<i32, f64> for SimpleDistance {

src/base.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//! # Common Interfaces and methods
2+
//!
3+
//! This module consolidates interfaces and uniform basic API that is used elsewhere in the code.
4+
5+
use crate::error::Failed;
6+
7+
/// Implements method predict that offers a way to estimate target value from new data
8+
pub trait Predictor<X, Y> {
9+
fn predict(&self, x: &X) -> Result<Y, Failed>;
10+
}

src/ensemble/random_forest_classifier.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//!
1010
//! ```
1111
//! use smartcore::linalg::naive::dense_matrix::*;
12-
//! use smartcore::ensemble::random_forest_classifier::*;
12+
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
1313
//!
1414
//! // Iris dataset
1515
//! let x = DenseMatrix::from_2d_array(&[
@@ -51,6 +51,7 @@ use std::fmt::Debug;
5151
use rand::Rng;
5252
use serde::{Deserialize, Serialize};
5353

54+
use crate::base::Predictor;
5455
use crate::error::Failed;
5556
use crate::linalg::Matrix;
5657
use crate::math::num::RealNumber;
@@ -117,6 +118,12 @@ impl Default for RandomForestClassifierParameters {
117118
}
118119
}
119120

121+
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
122+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
123+
self.predict(x)
124+
}
125+
}
126+
120127
impl<T: RealNumber> RandomForestClassifier<T> {
121128
/// Build a forest of trees from the training set.
122129
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

src/ensemble/random_forest_regressor.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use std::fmt::Debug;
4949
use rand::Rng;
5050
use serde::{Deserialize, Serialize};
5151

52+
use crate::base::Predictor;
5253
use crate::error::Failed;
5354
use crate::linalg::Matrix;
5455
use crate::math::num::RealNumber;
@@ -106,6 +107,12 @@ impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
106107
}
107108
}
108109

110+
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
111+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
112+
self.predict(x)
113+
}
114+
}
115+
109116
impl<T: RealNumber> RandomForestRegressor<T> {
110117
/// Build a forest of trees from the training set.
111118
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,15 @@
6363
//! let y = vec![2., 2., 2., 3., 3.];
6464
//!
6565
//! // Train classifier
66-
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
66+
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
6767
//!
6868
//! // Predict classes
6969
//! let y_hat = knn.predict(&x).unwrap();
7070
//! ```
7171
7272
/// Various algorithms and helper methods that are used elsewhere in SmartCore
7373
pub mod algorithm;
74+
pub(crate) mod base;
7475
/// Algorithms for clustering of unlabeled data
7576
pub mod cluster;
7677
/// Various datasets

src/linalg/mod.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,19 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
274274

275275
/// Copies content of `other` vector.
276276
fn copy_from(&mut self, other: &Self);
277+
278+
/// Take elements from an array.
279+
fn take(&self, index: &[usize]) -> Self {
280+
let n = index.len();
281+
282+
let mut result = Self::zeros(n);
283+
284+
for i in 0..n {
285+
result.set(i, self.get(index[i]));
286+
}
287+
288+
result
289+
}
277290
}
278291

279292
/// Generic matrix type.
@@ -611,6 +624,32 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
611624

612625
/// Calculates the covariance matrix
613626
fn cov(&self) -> Self;
627+
628+
/// Take elements from an array along an axis.
629+
fn take(&self, index: &[usize], axis: u8) -> Self {
630+
let (n, p) = self.shape();
631+
632+
let k = match axis {
633+
0 => p,
634+
_ => n,
635+
};
636+
637+
let mut result = match axis {
638+
0 => Self::zeros(index.len(), p),
639+
_ => Self::zeros(n, index.len()),
640+
};
641+
642+
for i in 0..index.len() {
643+
for j in 0..k {
644+
match axis {
645+
0 => result.set(i, j, self.get(index[i], j)),
646+
_ => result.set(j, i, self.get(j, index[i])),
647+
};
648+
}
649+
}
650+
651+
result
652+
}
614653
}
615654

616655
/// Generic matrix with additional mixins like various factorization methods.
@@ -662,6 +701,8 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
662701

663702
#[cfg(test)]
664703
mod tests {
704+
use crate::linalg::naive::dense_matrix::DenseMatrix;
705+
use crate::linalg::BaseMatrix;
665706
use crate::linalg::BaseVector;
666707

667708
#[test]
@@ -684,4 +725,35 @@ mod tests {
684725

685726
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
686727
}
728+
729+
#[test]
730+
fn vec_take() {
731+
let m = vec![1., 2., 3., 4., 5.];
732+
733+
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
734+
}
735+
736+
#[test]
737+
fn take() {
738+
let m = DenseMatrix::from_2d_array(&[
739+
&[1.0, 2.0],
740+
&[3.0, 4.0],
741+
&[5.0, 6.0],
742+
&[7.0, 8.0],
743+
&[9.0, 10.0],
744+
]);
745+
746+
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
747+
748+
let expected_1 = DenseMatrix::from_2d_array(&[
749+
&[2.0, 1.0],
750+
&[4.0, 3.0],
751+
&[6.0, 5.0],
752+
&[8.0, 7.0],
753+
&[10.0, 9.0],
754+
]);
755+
756+
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
757+
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
758+
}
687759
}

src/linalg/ndarray_bindings.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
3737
//! ]);
3838
//!
39-
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
39+
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
4040
//! let y_hat = lr.predict(&x).unwrap();
4141
//! ```
4242
use std::iter::Sum;
@@ -917,7 +917,7 @@ mod tests {
917917
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
918918
]);
919919

920-
let lr = LogisticRegression::fit(&x, &y).unwrap();
920+
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
921921

922922
let y_hat = lr.predict(&x).unwrap();
923923

src/linear/elastic_net.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ use std::fmt::Debug;
5858

5959
use serde::{Deserialize, Serialize};
6060

61+
use crate::base::Predictor;
6162
use crate::error::Failed;
6263
use crate::linalg::BaseVector;
6364
use crate::linalg::Matrix;
@@ -66,7 +67,7 @@ use crate::math::num::RealNumber;
6667
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
6768

6869
/// Elastic net parameters
69-
#[derive(Serialize, Deserialize, Debug)]
70+
#[derive(Serialize, Deserialize, Debug, Clone)]
7071
pub struct ElasticNetParameters<T: RealNumber> {
7172
/// Regularization parameter.
7273
pub alpha: T,
@@ -108,6 +109,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
108109
}
109110
}
110111

112+
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
113+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
114+
self.predict(x)
115+
}
116+
}
117+
111118
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
112119
/// Fits elastic net regression to your data.
113120
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

src/linear/lasso.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ use std::fmt::Debug;
2626

2727
use serde::{Deserialize, Serialize};
2828

29+
use crate::base::Predictor;
2930
use crate::error::Failed;
3031
use crate::linalg::BaseVector;
3132
use crate::linalg::Matrix;
3233
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
3334
use crate::math::num::RealNumber;
3435

3536
/// Lasso regression parameters
36-
#[derive(Serialize, Deserialize, Debug)]
37+
#[derive(Serialize, Deserialize, Debug, Clone)]
3738
pub struct LassoParameters<T: RealNumber> {
3839
/// Controls the strength of the penalty to the loss function.
3940
pub alpha: T,
@@ -71,6 +72,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
7172
}
7273
}
7374

75+
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
76+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
77+
self.predict(x)
78+
}
79+
}
80+
7481
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
7582
/// Fits Lasso regression to your data.
7683
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

0 commit comments

Comments
 (0)