Skip to content

Commit a69fb3a

Browse files
Merge pull request #43 from smartcorelib/kfold
Kfold
2 parents 40dfca7 + d22be7d commit a69fb3a

37 files changed

+1257
-412
lines changed

src/algorithm/neighbour/cover_tree.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//! use smartcore::algorithm::neighbour::cover_tree::*;
77
//! use smartcore::math::distance::Distance;
88
//!
9+
//! #[derive(Clone)]
910
//! struct SimpleDistance {} // Our distance function
1011
//!
1112
//! impl Distance<i32, f64> for SimpleDistance {
@@ -453,7 +454,7 @@ mod tests {
453454
use super::*;
454455
use crate::math::distance::Distances;
455456

456-
#[derive(Debug, Serialize, Deserialize)]
457+
#[derive(Debug, Serialize, Deserialize, Clone)]
457458
struct SimpleDistance {}
458459

459460
impl Distance<i32, f64> for SimpleDistance {

src/algorithm/neighbour/linear_search.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//! use smartcore::algorithm::neighbour::linear_search::*;
66
//! use smartcore::math::distance::Distance;
77
//!
8+
//! #[derive(Clone)]
89
//! struct SimpleDistance {} // Our distance function
910
//!
1011
//! impl Distance<i32, f64> for SimpleDistance {
@@ -137,6 +138,7 @@ mod tests {
137138
use super::*;
138139
use crate::math::distance::Distances;
139140

141+
#[derive(Debug, Serialize, Deserialize, Clone)]
140142
struct SimpleDistance {}
141143

142144
impl Distance<i32, f64> for SimpleDistance {

src/base.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//! # Common Interfaces and methods
2+
//!
3+
//! This module consolidates the common interfaces and the uniform basic API that are used elsewhere in the code.
4+
5+
use crate::error::Failed;
6+
7+
/// Common prediction interface: offers a way to estimate target values from new data.
///
/// `X` is the type of the input data and `Y` is the type of the returned estimates.
8+
pub trait Predictor<X, Y> {
9+
/// Estimates target values for `x`, returning them as `Y` or a `Failed` error.
fn predict(&self, x: &X) -> Result<Y, Failed>;
10+
}

src/cluster/dbscan.rs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
//! let blobs = generator::make_blobs(100, 2, 3);
1616
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
1717
//! // Fit the algorithm and predict cluster labels
18-
//! let labels = DBSCAN::fit(&x, Distances::euclidian(), DBSCANParameters{
19-
//! min_samples: 5,
20-
//! eps: 3.0,
21-
//! algorithm: KNNAlgorithmName::CoverTree
22-
//! }).and_then(|dbscan| dbscan.predict(&x));
18+
//! let labels = DBSCAN::fit(&x, Distances::euclidian(),
19+
//! DBSCANParameters::default().with_eps(3.0)).
20+
//! and_then(|dbscan| dbscan.predict(&x));
2321
//!
2422
//! println!("{:?}", labels);
2523
//! ```
@@ -53,14 +51,32 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
5351
#[derive(Debug, Clone)]
5452
/// DBSCAN clustering algorithm parameters
5553
pub struct DBSCANParameters<T: RealNumber> {
56-
/// Maximum number of iterations of the k-means algorithm for a single run.
54+
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
5755
pub min_samples: usize,
58-
/// The number of samples in a neighborhood for a point to be considered as a core point.
56+
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
5957
pub eps: T,
6058
/// KNN algorithm to use.
6159
pub algorithm: KNNAlgorithmName,
6260
}
6361

62+
impl<T: RealNumber> DBSCANParameters<T> {
63+
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
64+
pub fn with_min_samples(mut self, min_samples: usize) -> Self {
65+
self.min_samples = min_samples;
66+
self
67+
}
68+
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
69+
pub fn with_eps(mut self, eps: T) -> Self {
70+
self.eps = eps;
71+
self
72+
}
73+
/// KNN algorithm to use.
74+
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
75+
self.algorithm = algorithm;
76+
self
77+
}
78+
}
79+
6480
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
6581
fn eq(&self, other: &Self) -> bool {
6682
self.cluster_labels.len() == other.cluster_labels.len()

src/cluster/kmeans.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ pub struct KMeansParameters {
105105
pub max_iter: usize,
106106
}
107107

108+
impl KMeansParameters {
109+
/// Maximum number of iterations of the k-means algorithm for a single run.
110+
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
111+
self.max_iter = max_iter;
112+
self
113+
}
114+
}
115+
108116
impl Default for KMeansParameters {
109117
fn default() -> Self {
110118
KMeansParameters { max_iter: 100 }

src/decomposition/pca.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ pub struct PCAParameters {
8888
pub use_correlation_matrix: bool,
8989
}
9090

91+
impl PCAParameters {
92+
/// By default, covariance matrix is used to compute principal components.
93+
/// Enable this flag if you want to use correlation matrix instead.
94+
pub fn with_use_correlation_matrix(mut self, use_correlation_matrix: bool) -> Self {
95+
self.use_correlation_matrix = use_correlation_matrix;
96+
self
97+
}
98+
}
99+
91100
impl Default for PCAParameters {
92101
fn default() -> Self {
93102
PCAParameters {

src/ensemble/random_forest_classifier.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//!
1010
//! ```
1111
//! use smartcore::linalg::naive::dense_matrix::*;
12-
//! use smartcore::ensemble::random_forest_classifier::*;
12+
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
1313
//!
1414
//! // Iris dataset
1515
//! let x = DenseMatrix::from_2d_array(&[
@@ -51,6 +51,7 @@ use std::fmt::Debug;
5151
use rand::Rng;
5252
use serde::{Deserialize, Serialize};
5353

54+
use crate::base::Predictor;
5455
use crate::error::Failed;
5556
use crate::linalg::Matrix;
5657
use crate::math::num::RealNumber;
@@ -84,6 +85,39 @@ pub struct RandomForestClassifier<T: RealNumber> {
8485
classes: Vec<T>,
8586
}
8687

88+
impl RandomForestClassifierParameters {
89+
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
90+
pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
91+
self.criterion = criterion;
92+
self
93+
}
94+
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
95+
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
96+
self.max_depth = Some(max_depth);
97+
self
98+
}
99+
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
100+
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
101+
self.min_samples_leaf = min_samples_leaf;
102+
self
103+
}
104+
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
105+
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
106+
self.min_samples_split = min_samples_split;
107+
self
108+
}
109+
/// The number of trees in the forest.
110+
pub fn with_n_trees(mut self, n_trees: u16) -> Self {
111+
self.n_trees = n_trees;
112+
self
113+
}
114+
/// Number of random sample of predictors to use as split candidates.
115+
pub fn with_m(mut self, m: usize) -> Self {
116+
self.m = Some(m);
117+
self
118+
}
119+
}
120+
87121
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
88122
fn eq(&self, other: &Self) -> bool {
89123
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
@@ -117,6 +151,12 @@ impl Default for RandomForestClassifierParameters {
117151
}
118152
}
119153

154+
// Exposes `RandomForestClassifier` through the common `Predictor` interface,
// with a matrix `M` as input and `M::RowVector` as output.
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
155+
/// Delegates the prediction for `x`, returning `M::RowVector` or a `Failed` error.
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
156+
// NOTE(review): presumably this resolves to an inherent `RandomForestClassifier::predict`
// (inherent methods take precedence over trait methods) rather than calling itself —
// confirm the inherent method exists with a compatible signature.
self.predict(x)
157+
}
158+
}
159+
120160
impl<T: RealNumber> RandomForestClassifier<T> {
121161
/// Build a forest of trees from the training set.
122162
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

src/ensemble/random_forest_regressor.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use std::fmt::Debug;
4949
use rand::Rng;
5050
use serde::{Deserialize, Serialize};
5151

52+
use crate::base::Predictor;
5253
use crate::error::Failed;
5354
use crate::linalg::Matrix;
5455
use crate::math::num::RealNumber;
@@ -79,6 +80,34 @@ pub struct RandomForestRegressor<T: RealNumber> {
7980
trees: Vec<DecisionTreeRegressor<T>>,
8081
}
8182

83+
impl RandomForestRegressorParameters {
84+
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
85+
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
86+
self.max_depth = Some(max_depth);
87+
self
88+
}
89+
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
90+
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
91+
self.min_samples_leaf = min_samples_leaf;
92+
self
93+
}
94+
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
95+
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
96+
self.min_samples_split = min_samples_split;
97+
self
98+
}
99+
/// The number of trees in the forest.
100+
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
101+
self.n_trees = n_trees;
102+
self
103+
}
104+
/// Number of random sample of predictors to use as split candidates.
105+
pub fn with_m(mut self, m: usize) -> Self {
106+
self.m = Some(m);
107+
self
108+
}
109+
}
110+
82111
impl Default for RandomForestRegressorParameters {
83112
fn default() -> Self {
84113
RandomForestRegressorParameters {
@@ -106,6 +135,12 @@ impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
106135
}
107136
}
108137

138+
// Exposes `RandomForestRegressor` through the common `Predictor` interface,
// with a matrix `M` as input and `M::RowVector` as output.
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
139+
/// Delegates the prediction for `x`, returning `M::RowVector` or a `Failed` error.
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
140+
// NOTE(review): presumably this resolves to an inherent `RandomForestRegressor::predict`
// (inherent methods take precedence over trait methods) rather than calling itself —
// confirm the inherent method exists with a compatible signature.
self.predict(x)
141+
}
142+
}
143+
109144
impl<T: RealNumber> RandomForestRegressor<T> {
110145
/// Build a forest of trees from the training set.
111146
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,15 @@
6363
//! let y = vec![2., 2., 2., 3., 3.];
6464
//!
6565
//! // Train classifier
66-
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
66+
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
6767
//!
6868
//! // Predict classes
6969
//! let y_hat = knn.predict(&x).unwrap();
7070
//! ```
7171
7272
/// Various algorithms and helper methods that are used elsewhere in SmartCore
7373
pub mod algorithm;
74+
pub(crate) mod base;
7475
/// Algorithms for clustering of unlabeled data
7576
pub mod cluster;
7677
/// Various datasets

src/linalg/mod.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,19 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
274274

275275
/// Copies content of `other` vector.
276276
fn copy_from(&mut self, other: &Self);
277+
278+
/// Take elements from an array.
279+
fn take(&self, index: &[usize]) -> Self {
280+
let n = index.len();
281+
282+
let mut result = Self::zeros(n);
283+
284+
for (i, idx) in index.iter().enumerate() {
285+
result.set(i, self.get(*idx));
286+
}
287+
288+
result
289+
}
277290
}
278291

279292
/// Generic matrix type.
@@ -611,6 +624,32 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
611624

612625
/// Calculates the covariance matrix
613626
fn cov(&self) -> Self;
627+
628+
/// Take elements from an array along an axis.
629+
fn take(&self, index: &[usize], axis: u8) -> Self {
630+
let (n, p) = self.shape();
631+
632+
let k = match axis {
633+
0 => p,
634+
_ => n,
635+
};
636+
637+
let mut result = match axis {
638+
0 => Self::zeros(index.len(), p),
639+
_ => Self::zeros(n, index.len()),
640+
};
641+
642+
for (i, idx) in index.iter().enumerate() {
643+
for j in 0..k {
644+
match axis {
645+
0 => result.set(i, j, self.get(*idx, j)),
646+
_ => result.set(j, i, self.get(j, *idx)),
647+
};
648+
}
649+
}
650+
651+
result
652+
}
614653
}
615654

616655
/// Generic matrix with additional mixins like various factorization methods.
@@ -662,6 +701,8 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
662701

663702
#[cfg(test)]
664703
mod tests {
704+
use crate::linalg::naive::dense_matrix::DenseMatrix;
705+
use crate::linalg::BaseMatrix;
665706
use crate::linalg::BaseVector;
666707

667708
#[test]
@@ -684,4 +725,35 @@ mod tests {
684725

685726
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
686727
}
728+
729+
// Checks that `BaseVector::take` picks elements by position, including repeated positions.
#[test]
fn vec_take() {
    let m = vec![1., 2., 3., 4., 5.];

    // A slice literal is enough for the indices; no need to allocate a `Vec`.
    assert_eq!(m.take(&[0, 0, 4, 4]), vec![1., 1., 5., 5.]);
}
735+
736+
// Checks `BaseMatrix::take` along both axes: axis 0 picks rows, axis 1 picks columns.
#[test]
fn take() {
    let m = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[3.0, 4.0],
        &[5.0, 6.0],
        &[7.0, 8.0],
        &[9.0, 10.0],
    ]);

    // Rows 1, 1 and 3 of `m` (a repeated index yields a repeated row).
    let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);

    // Columns 1 and 0 of `m`, i.e. the two columns swapped.
    let expected_1 = DenseMatrix::from_2d_array(&[
        &[2.0, 1.0],
        &[4.0, 3.0],
        &[6.0, 5.0],
        &[8.0, 7.0],
        &[10.0, 9.0],
    ]);

    // Slice literals are enough for the indices; no need to allocate a `Vec`.
    assert_eq!(m.take(&[1, 1, 3], 0), expected_0);
    assert_eq!(m.take(&[1, 0], 1), expected_1);
}
687759
}

0 commit comments

Comments
 (0)