Skip to content

Commit ba16c25

Browse files
Merge pull request #44 from smartcorelib/api
feat: consolidates API
2 parents a69fb3a + 810a5c4 commit ba16c25

25 files changed

+400
-98
lines changed

src/api.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//! # Common Interfaces and API
2+
//!
3+
//! This module provides interfaces and uniform API with simple conventions
4+
//! that are used in other modules for supervised and unsupervised learning.
5+
6+
use crate::error::Failed;
7+
8+
/// An estimator for unsupervised learning, that provides method `fit` to learn from data
9+
pub trait UnsupervisedEstimator<X, P> {
10+
/// Fit a model to a training dataset, estimate model's parameters.
11+
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
12+
/// * `parameters` - hyperparameters of an algorithm
13+
fn fit(x: &X, parameters: P) -> Result<Self, Failed>
14+
where
15+
Self: Sized,
16+
P: Clone;
17+
}
18+
19+
/// An estimator for supervised learning, , that provides method `fit` to learn from data and training values
20+
pub trait SupervisedEstimator<X, Y, P> {
21+
/// Fit a model to a training dataset, estimate model's parameters.
22+
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
23+
/// * `y` - target training values of size _N_.
24+
/// * `parameters` - hyperparameters of an algorithm
25+
fn fit(x: &X, y: &Y, parameters: P) -> Result<Self, Failed>
26+
where
27+
Self: Sized,
28+
P: Clone;
29+
}
30+
31+
/// Implements method predict that estimates target value from new data
32+
pub trait Predictor<X, Y> {
33+
/// Estimate target values from new data.
34+
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
35+
fn predict(&self, x: &X) -> Result<Y, Failed>;
36+
}
37+
38+
/// Implements method transform that filters or modifies input data
39+
pub trait Transformer<X> {
40+
/// Transform data by modifying or filtering it
41+
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
42+
fn transform(&self, x: &X) -> Result<X, Failed>;
43+
}

src/base.rs

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/cluster/dbscan.rs

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
//! let blobs = generator::make_blobs(100, 2, 3);
1616
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
1717
//! // Fit the algorithm and predict cluster labels
18-
//! let labels = DBSCAN::fit(&x, Distances::euclidian(),
19-
//! DBSCANParameters::default().with_eps(3.0)).
18+
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
2019
//! and_then(|dbscan| dbscan.predict(&x));
2120
//!
2221
//! println!("{:?}", labels);
@@ -33,9 +32,11 @@ use std::iter::Sum;
3332
use serde::{Deserialize, Serialize};
3433

3534
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
35+
use crate::api::{Predictor, UnsupervisedEstimator};
3636
use crate::error::Failed;
3737
use crate::linalg::{row_iter, Matrix};
38-
use crate::math::distance::Distance;
38+
use crate::math::distance::euclidian::Euclidian;
39+
use crate::math::distance::{Distance, Distances};
3940
use crate::math::num::RealNumber;
4041
use crate::tree::decision_tree_classifier::which_max;
4142

@@ -50,7 +51,11 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
5051

5152
#[derive(Debug, Clone)]
5253
/// DBSCAN clustering algorithm parameters
53-
pub struct DBSCANParameters<T: RealNumber> {
54+
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
55+
/// a function that defines a distance between each pair of point in training data.
56+
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
57+
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
58+
pub distance: D,
5459
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
5560
pub min_samples: usize,
5661
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
@@ -59,7 +64,18 @@ pub struct DBSCANParameters<T: RealNumber> {
5964
pub algorithm: KNNAlgorithmName,
6065
}
6166

62-
impl<T: RealNumber> DBSCANParameters<T> {
67+
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
68+
/// a function that defines a distance between each pair of point in training data.
69+
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
70+
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
71+
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
72+
DBSCANParameters {
73+
distance,
74+
min_samples: self.min_samples,
75+
eps: self.eps,
76+
algorithm: self.algorithm,
77+
}
78+
}
6379
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
6480
pub fn with_min_samples(mut self, min_samples: usize) -> Self {
6581
self.min_samples = min_samples;
@@ -86,25 +102,41 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
86102
}
87103
}
88104

89-
impl<T: RealNumber> Default for DBSCANParameters<T> {
105+
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
90106
fn default() -> Self {
91107
DBSCANParameters {
108+
distance: Distances::euclidian(),
92109
min_samples: 5,
93110
eps: T::half(),
94111
algorithm: KNNAlgorithmName::CoverTree,
95112
}
96113
}
97114
}
98115

116+
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
117+
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
118+
{
119+
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
120+
DBSCAN::fit(x, parameters)
121+
}
122+
}
123+
124+
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
125+
for DBSCAN<T, D>
126+
{
127+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
128+
self.predict(x)
129+
}
130+
}
131+
99132
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
100133
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
101134
/// * `data` - training instances to cluster
102135
/// * `k` - number of clusters
103136
/// * `parameters` - cluster parameters
104137
pub fn fit<M: Matrix<T>>(
105138
x: &M,
106-
distance: D,
107-
parameters: DBSCANParameters<T>,
139+
parameters: DBSCANParameters<T, D>,
108140
) -> Result<DBSCAN<T, D>, Failed> {
109141
if parameters.min_samples < 1 {
110142
return Err(Failed::fit(&"Invalid minPts".to_string()));
@@ -121,7 +153,9 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
121153
let n = x.shape().0;
122154
let mut y = vec![unassigned; n];
123155

124-
let algo = parameters.algorithm.fit(row_iter(x).collect(), distance)?;
156+
let algo = parameters
157+
.algorithm
158+
.fit(row_iter(x).collect(), parameters.distance)?;
125159

126160
for (i, e) in row_iter(x).enumerate() {
127161
if y[i] == unassigned {
@@ -195,7 +229,6 @@ mod tests {
195229
use super::*;
196230
use crate::linalg::naive::dense_matrix::DenseMatrix;
197231
use crate::math::distance::euclidian::Euclidian;
198-
use crate::math::distance::Distances;
199232

200233
#[test]
201234
fn fit_predict_dbscan() {
@@ -215,16 +248,7 @@ mod tests {
215248

216249
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
217250

218-
let dbscan = DBSCAN::fit(
219-
&x,
220-
Distances::euclidian(),
221-
DBSCANParameters {
222-
min_samples: 5,
223-
eps: 1.0,
224-
algorithm: KNNAlgorithmName::CoverTree,
225-
},
226-
)
227-
.unwrap();
251+
let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap();
228252

229253
let predicted_labels = dbscan.predict(&x).unwrap();
230254

@@ -256,7 +280,7 @@ mod tests {
256280
&[5.2, 2.7, 3.9, 1.4],
257281
]);
258282

259-
let dbscan = DBSCAN::fit(&x, Distances::euclidian(), Default::default()).unwrap();
283+
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
260284

261285
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
262286
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();

src/cluster/kmeans.rs

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
//! &[5.2, 2.7, 3.9, 1.4],
4444
//! ]);
4545
//!
46-
//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
46+
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
4747
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
4848
//! ```
4949
//!
@@ -59,6 +59,7 @@ use std::iter::Sum;
5959
use serde::{Deserialize, Serialize};
6060

6161
use crate::algorithm::neighbour::bbd_tree::BBDTree;
62+
use crate::api::{Predictor, UnsupervisedEstimator};
6263
use crate::error::Failed;
6364
use crate::linalg::Matrix;
6465
use crate::math::distance::euclidian::*;
@@ -101,11 +102,18 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
101102
#[derive(Debug, Clone)]
102103
/// K-Means clustering algorithm parameters
103104
pub struct KMeansParameters {
105+
/// Number of clusters.
106+
pub k: usize,
104107
/// Maximum number of iterations of the k-means algorithm for a single run.
105108
pub max_iter: usize,
106109
}
107110

108111
impl KMeansParameters {
112+
/// Number of clusters.
113+
pub fn with_k(mut self, k: usize) -> Self {
114+
self.k = k;
115+
self
116+
}
109117
/// Maximum number of iterations of the k-means algorithm for a single run.
110118
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
111119
self.max_iter = max_iter;
@@ -115,24 +123,37 @@ impl KMeansParameters {
115123

116124
impl Default for KMeansParameters {
117125
fn default() -> Self {
118-
KMeansParameters { max_iter: 100 }
126+
KMeansParameters {
127+
k: 2,
128+
max_iter: 100,
129+
}
130+
}
131+
}
132+
133+
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
134+
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
135+
KMeans::fit(x, parameters)
136+
}
137+
}
138+
139+
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
140+
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
141+
self.predict(x)
119142
}
120143
}
121144

122145
impl<T: RealNumber + Sum> KMeans<T> {
123146
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
124-
/// * `data` - training instances to cluster
125-
/// * `k` - number of clusters
147+
/// * `data` - training instances to cluster
126148
/// * `parameters` - cluster parameters
127-
pub fn fit<M: Matrix<T>>(
128-
data: &M,
129-
k: usize,
130-
parameters: KMeansParameters,
131-
) -> Result<KMeans<T>, Failed> {
149+
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
132150
let bbd = BBDTree::new(data);
133151

134-
if k < 2 {
135-
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
152+
if parameters.k < 2 {
153+
return Err(Failed::fit(&format!(
154+
"invalid number of clusters: {}",
155+
parameters.k
156+
)));
136157
}
137158

138159
if parameters.max_iter == 0 {
@@ -145,9 +166,9 @@ impl<T: RealNumber + Sum> KMeans<T> {
145166
let (n, d) = data.shape();
146167

147168
let mut distortion = T::max_value();
148-
let mut y = KMeans::kmeans_plus_plus(data, k);
149-
let mut size = vec![0; k];
150-
let mut centroids = vec![vec![T::zero(); d]; k];
169+
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
170+
let mut size = vec![0; parameters.k];
171+
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
151172

152173
for i in 0..n {
153174
size[y[i]] += 1;
@@ -159,16 +180,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
159180
}
160181
}
161182

162-
for i in 0..k {
183+
for i in 0..parameters.k {
163184
for j in 0..d {
164185
centroids[i][j] /= T::from(size[i]).unwrap();
165186
}
166187
}
167188

168-
let mut sums = vec![vec![T::zero(); d]; k];
189+
let mut sums = vec![vec![T::zero(); d]; parameters.k];
169190
for _ in 1..=parameters.max_iter {
170191
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
171-
for i in 0..k {
192+
for i in 0..parameters.k {
172193
if size[i] > 0 {
173194
for j in 0..d {
174195
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
@@ -184,7 +205,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
184205
}
185206

186207
Ok(KMeans {
187-
k,
208+
k: parameters.k,
188209
y,
189210
size,
190211
distortion,
@@ -280,10 +301,10 @@ mod tests {
280301
fn invalid_k() {
281302
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
282303

283-
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
304+
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
284305
assert_eq!(
285306
"Fit failed: invalid number of clusters: 1",
286-
KMeans::fit(&x, 1, Default::default())
307+
KMeans::fit(&x, KMeansParameters::default().with_k(1))
287308
.unwrap_err()
288309
.to_string()
289310
);
@@ -314,7 +335,7 @@ mod tests {
314335
&[5.2, 2.7, 3.9, 1.4],
315336
]);
316337

317-
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
338+
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
318339

319340
let y = kmeans.predict(&x).unwrap();
320341

@@ -348,7 +369,7 @@ mod tests {
348369
&[5.2, 2.7, 3.9, 1.4],
349370
]);
350371

351-
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
372+
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
352373

353374
let deserialized_kmeans: KMeans<f64> =
354375
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();

0 commit comments

Comments
 (0)