Skip to content

Commit a9db970

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
feat: refactoring, adds Result to most public API
1 parent 4921ae7 commit a9db970

24 files changed

+383
-292
lines changed

src/algorithm/neighbour/cover_tree.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//!
1717
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
1818
//!
19-
//! let mut tree = CoverTree::new(data, SimpleDistance {});
19+
//! let mut tree = CoverTree::new(data, SimpleDistance {}).unwrap();
2020
//!
2121
//! tree.find(&5, 3); // find 3 knn points from 5
2222
//!
@@ -26,6 +26,7 @@ use std::fmt::Debug;
2626
use serde::{Deserialize, Serialize};
2727

2828
use crate::algorithm::sort::heap_select::HeapSelection;
29+
use crate::error::{Failed, FailedError};
2930
use crate::math::distance::Distance;
3031
use crate::math::num::RealNumber;
3132

@@ -73,7 +74,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
7374
/// Construct a cover tree.
7475
/// * `data` - vector of data points to search for.
7576
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
76-
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
77+
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
7778
let base = F::from_f64(1.3).unwrap();
7879
let root = Node {
7980
idx: 0,
@@ -93,19 +94,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
9394

9495
tree.build_cover_tree();
9596

96-
tree
97+
Ok(tree)
9798
}
9899

99100
/// Find k nearest neighbors of `p`
100101
/// * `p` - look for k nearest points to `p`
101102
/// * `k` - the number of nearest neighbors to return
102-
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
103+
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
103104
if k <= 0 {
104-
panic!("k should be > 0");
105+
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
105106
}
106107

107108
if k > self.data.len() {
108-
panic!("k is > than the dataset size");
109+
return Err(Failed::because(
110+
FailedError::FindFailed,
111+
"k is > than the dataset size",
112+
));
109113
}
110114

111115
let e = self.get_data_value(self.root.idx);
@@ -171,7 +175,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
171175
}
172176
}
173177

174-
neighbors.into_iter().take(k).collect()
178+
Ok(neighbors.into_iter().take(k).collect())
175179
}
176180

177181
fn new_leaf(&self, idx: usize) -> Node<F> {
@@ -407,9 +411,9 @@ mod tests {
407411
fn cover_tree_test() {
408412
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
409413

410-
let tree = CoverTree::new(data, SimpleDistance {});
414+
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
411415

412-
let mut knn = tree.find(&5, 3);
416+
let mut knn = tree.find(&5, 3).unwrap();
413417
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
414418
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
415419
assert_eq!(vec!(3, 4, 5), knn);
@@ -425,9 +429,9 @@ mod tests {
425429
vec![9., 10.],
426430
];
427431

428-
let tree = CoverTree::new(data, Distances::euclidian());
432+
let tree = CoverTree::new(data, Distances::euclidian()).unwrap();
429433

430-
let mut knn = tree.find(&vec![1., 2.], 3);
434+
let mut knn = tree.find(&vec![1., 2.], 3).unwrap();
431435
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
432436
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
433437

@@ -438,7 +442,7 @@ mod tests {
438442
fn serde() {
439443
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
440444

441-
let tree = CoverTree::new(data, SimpleDistance {});
445+
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
442446

443447
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
444448
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();

src/algorithm/neighbour/linear_search.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
//!
1616
//! let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // data points
1717
//!
18-
//! let knn = LinearKNNSearch::new(data, SimpleDistance {});
18+
//! let knn = LinearKNNSearch::new(data, SimpleDistance {}).unwrap();
1919
//!
2020
//! knn.find(&5, 3); // find 3 knn points from 5
2121
//!
@@ -26,6 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
2626
use std::marker::PhantomData;
2727

2828
use crate::algorithm::sort::heap_select::HeapSelection;
29+
use crate::error::Failed;
2930
use crate::math::distance::Distance;
3031
use crate::math::num::RealNumber;
3132

@@ -41,18 +42,18 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
4142
/// Initializes algorithm.
4243
/// * `data` - vector of data points to search for.
4344
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
44-
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D> {
45-
LinearKNNSearch {
45+
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
46+
Ok(LinearKNNSearch {
4647
data: data,
4748
distance: distance,
4849
f: PhantomData,
49-
}
50+
})
5051
}
5152

5253
/// Find k nearest neighbors
5354
/// * `from` - look for k nearest points to `from`
5455
/// * `k` - the number of nearest neighbors to return
55-
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
56+
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
5657
if k < 1 || k > self.data.len() {
5758
panic!("k should be >= 1 and <= length(data)");
5859
}
@@ -76,10 +77,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
7677
}
7778
}
7879

79-
heap.get()
80+
Ok(heap
81+
.get()
8082
.into_iter()
8183
.flat_map(|x| x.index.map(|i| (i, x.distance)))
82-
.collect()
84+
.collect())
8385
}
8486
}
8587

@@ -120,9 +122,14 @@ mod tests {
120122
fn knn_find() {
121123
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
122124

123-
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
125+
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}).unwrap();
124126

125-
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
127+
let mut found_idxs1: Vec<usize> = algorithm1
128+
.find(&2, 3)
129+
.unwrap()
130+
.iter()
131+
.map(|v| v.0)
132+
.collect();
126133
found_idxs1.sort();
127134

128135
assert_eq!(vec!(0, 1, 2), found_idxs1);
@@ -135,10 +142,11 @@ mod tests {
135142
vec![5., 5.],
136143
];
137144

138-
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
145+
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()).unwrap();
139146

140147
let mut found_idxs2: Vec<usize> = algorithm2
141148
.find(&vec![3., 3.], 3)
149+
.unwrap()
142150
.iter()
143151
.map(|v| v.0)
144152
.collect();

src/cluster/kmeans.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ use std::iter::Sum;
6161
use serde::{Deserialize, Serialize};
6262

6363
use crate::algorithm::neighbour::bbd_tree::BBDTree;
64-
use crate::error::{FitFailedError, PredictFailedError};
64+
use crate::error::Failed;
6565
use crate::linalg::Matrix;
6666
use crate::math::distance::euclidian::*;
6767
use crate::math::num::RealNumber;
@@ -122,19 +122,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
122122
data: &M,
123123
k: usize,
124124
parameters: KMeansParameters,
125-
) -> Result<KMeans<T>, FitFailedError> {
125+
) -> Result<KMeans<T>, Failed> {
126126
let bbd = BBDTree::new(data);
127127

128128
if k < 2 {
129-
return Err(FitFailedError::new(&format!(
130-
"Invalid number of clusters: {}",
131-
k
132-
)));
129+
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
133130
}
134131

135132
if parameters.max_iter <= 0 {
136-
return Err(FitFailedError::new(&format!(
137-
"Invalid maximum number of iterations: {}",
133+
return Err(Failed::fit(&format!(
134+
"invalid maximum number of iterations: {}",
138135
parameters.max_iter
139136
)));
140137
}
@@ -191,7 +188,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
191188

192189
/// Predict clusters for `x`
193190
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
194-
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> {
191+
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
195192
let (n, _) = x.shape();
196193
let mut result = M::zeros(1, n);
197194

@@ -274,11 +271,9 @@ mod tests {
274271
fn invalid_k() {
275272
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
276273

277-
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
278-
279274
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
280275
assert_eq!(
281-
"Invalid number of clusters: 1",
276+
"Fit failed: invalid number of clusters: 1",
282277
KMeans::fit(&x, 1, Default::default())
283278
.unwrap_err()
284279
.to_string()

src/decomposition/pca.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
//! &[5.2, 2.7, 3.9, 1.4],
3838
//! ]);
3939
//!
40-
//! let pca = PCA::new(&iris, 2, Default::default()); // Reduce number of features to 2
40+
//! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
4141
//!
42-
//! let iris_reduced = pca.transform(&iris);
42+
//! let iris_reduced = pca.transform(&iris).unwrap();
4343
//!
4444
//! ```
4545
//!
@@ -49,6 +49,7 @@ use std::fmt::Debug;
4949

5050
use serde::{Deserialize, Serialize};
5151

52+
use crate::error::Failed;
5253
use crate::linalg::Matrix;
5354
use crate::math::num::RealNumber;
5455

@@ -100,7 +101,11 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
100101
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
101102
/// * `n_components` - number of components to keep.
102103
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
103-
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
104+
pub fn fit(
105+
data: &M,
106+
n_components: usize,
107+
parameters: PCAParameters,
108+
) -> Result<PCA<T, M>, Failed> {
104109
let (m, n) = data.shape();
105110

106111
let mu = data.column_mean();
@@ -117,7 +122,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
117122
let mut eigenvectors;
118123

119124
if m > n && !parameters.use_correlation_matrix {
120-
let svd = x.svd();
125+
let svd = x.svd()?;
121126
eigenvalues = svd.s;
122127
for i in 0..eigenvalues.len() {
123128
eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
@@ -155,7 +160,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
155160
}
156161
}
157162

158-
let evd = cov.evd(true);
163+
let evd = cov.evd(true)?;
159164

160165
eigenvalues = evd.d;
161166

@@ -167,7 +172,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
167172
}
168173
}
169174
} else {
170-
let evd = cov.evd(true);
175+
let evd = cov.evd(true)?;
171176

172177
eigenvalues = evd.d;
173178

@@ -189,26 +194,26 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
189194
}
190195
}
191196

192-
PCA {
197+
Ok(PCA {
193198
eigenvectors: eigenvectors,
194199
eigenvalues: eigenvalues,
195200
projection: projection.transpose(),
196201
mu: mu,
197202
pmu: pmu,
198-
}
203+
})
199204
}
200205

201206
/// Run dimensionality reduction for `x`
202207
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
203-
pub fn transform(&self, x: &M) -> M {
208+
pub fn transform(&self, x: &M) -> Result<M, Failed> {
204209
let (nrows, ncols) = x.shape();
205210
let (_, n_components) = self.projection.shape();
206211
if ncols != self.mu.len() {
207-
panic!(
212+
return Err(Failed::transform(&format!(
208213
"Invalid input vector size: {}, expected: {}",
209214
ncols,
210215
self.mu.len()
211-
);
216+
)));
212217
}
213218

214219
let mut x_transformed = x.matmul(&self.projection);
@@ -217,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
217222
x_transformed.sub_element_mut(r, c, self.pmu[c]);
218223
}
219224
}
220-
x_transformed
225+
Ok(x_transformed)
221226
}
222227
}
223228

@@ -372,7 +377,7 @@ mod tests {
372377
302.04806302399646,
373378
];
374379

375-
let pca = PCA::new(&us_arrests, 4, Default::default());
380+
let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap();
376381

377382
assert!(pca
378383
.eigenvectors
@@ -383,7 +388,7 @@ mod tests {
383388
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
384389
}
385390

386-
let us_arrests_t = pca.transform(&us_arrests);
391+
let us_arrests_t = pca.transform(&us_arrests).unwrap();
387392

388393
assert!(us_arrests_t
389394
.abs()
@@ -481,13 +486,14 @@ mod tests {
481486
0.1734300877298357,
482487
];
483488

484-
let pca = PCA::new(
489+
let pca = PCA::fit(
485490
&us_arrests,
486491
4,
487492
PCAParameters {
488493
use_correlation_matrix: true,
489494
},
490-
);
495+
)
496+
.unwrap();
491497

492498
assert!(pca
493499
.eigenvectors
@@ -498,7 +504,7 @@ mod tests {
498504
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
499505
}
500506

501-
let us_arrests_t = pca.transform(&us_arrests);
507+
let us_arrests_t = pca.transform(&us_arrests).unwrap();
502508

503509
assert!(us_arrests_t
504510
.abs()
@@ -530,7 +536,7 @@ mod tests {
530536
&[5.2, 2.7, 3.9, 1.4],
531537
]);
532538

533-
let pca = PCA::new(&iris, 4, Default::default());
539+
let pca = PCA::fit(&iris, 4, Default::default()).unwrap();
534540

535541
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
536542
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();

0 commit comments

Comments
 (0)