Skip to content

Commit 9b7a2df

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
feat: adds FitFailedError and PredictFailedError
1 parent 4ba0cd3 commit 9b7a2df

File tree

3 files changed

+84
-14
lines changed

3 files changed

+84
-14
lines changed

src/cluster/kmeans.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
//! &[5.2, 2.7, 3.9, 1.4],
4444
//! ]);
4545
//!
46-
//! let kmeans = KMeans::new(&x, 2, Default::default()); // Fit to data, 2 clusters
47-
//! let y_hat = kmeans.predict(&x); // use the same points for prediction
46+
//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
47+
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
4848
//! ```
4949
//!
5050
//! ## References:
@@ -60,6 +60,7 @@ use std::iter::Sum;
6060

6161
use serde::{Deserialize, Serialize};
6262

63+
use crate::error::{FitFailedError, PredictFailedError};
6364
use crate::algorithm::neighbour::bbd_tree::BBDTree;
6465
use crate::linalg::Matrix;
6566
use crate::math::distance::euclidian::*;
@@ -117,18 +118,17 @@ impl<T: RealNumber + Sum> KMeans<T> {
117118
/// * `data` - training instances to cluster
118119
/// * `k` - number of clusters
119120
/// * `parameters` - cluster parameters
120-
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
121+
pub fn fit<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> Result<KMeans<T>, FitFailedError> {
121122
let bbd = BBDTree::new(data);
122123

123124
if k < 2 {
124-
panic!("Invalid number of clusters: {}", k);
125+
return Err(FitFailedError::new(&format!("Invalid number of clusters: {}", k)));
125126
}
126127

127128
if parameters.max_iter <= 0 {
128-
panic!(
129-
"Invalid maximum number of iterations: {}",
129+
return Err(FitFailedError::new(&format!("Invalid maximum number of iterations: {}",
130130
parameters.max_iter
131-
);
131+
)));
132132
}
133133

134134
let (n, d) = data.shape();
@@ -172,18 +172,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
172172
}
173173
}
174174

175-
KMeans {
175+
Ok(KMeans {
176176
k: k,
177177
y: y,
178178
size: size,
179179
distortion: distortion,
180180
centroids: centroids,
181-
}
181+
})
182182
}
183183

184184
/// Predict clusters for `x`
185185
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
186-
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
186+
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, PredictFailedError> {
187187
let (n, _) = x.shape();
188188
let mut result = M::zeros(1, n);
189189

@@ -201,7 +201,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
201201
result.set(0, i, T::from(best_cluster).unwrap());
202202
}
203203

204-
result.to_row_vector()
204+
Ok(result.to_row_vector())
205205
}
206206

207207
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
@@ -262,6 +262,20 @@ mod tests {
262262
use super::*;
263263
use crate::linalg::naive::dense_matrix::DenseMatrix;
264264

265+
#[test]
266+
fn invalid_k() {
267+
let x = DenseMatrix::from_2d_array(&[
268+
&[1., 2., 3.],
269+
&[4., 5., 6.],
270+
]);
271+
272+
println!("{:?}", KMeans::fit(&x, 0, Default::default()));
273+
274+
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
275+
assert_eq!("Invalid number of clusters: 1", KMeans::fit(&x, 1, Default::default()).unwrap_err().to_string());
276+
277+
}
278+
265279
#[test]
266280
fn fit_predict_iris() {
267281
let x = DenseMatrix::from_2d_array(&[
@@ -287,9 +301,9 @@ mod tests {
287301
&[5.2, 2.7, 3.9, 1.4],
288302
]);
289303

290-
let kmeans = KMeans::new(&x, 2, Default::default());
304+
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
291305

292-
let y = kmeans.predict(&x);
306+
let y = kmeans.predict(&x).unwrap();
293307

294308
for i in 0..y.len() {
295309
assert_eq!(y[i] as usize, kmeans.y[i]);
@@ -321,7 +335,7 @@ mod tests {
321335
&[5.2, 2.7, 3.9, 1.4],
322336
]);
323337

324-
let kmeans = KMeans::new(&x, 2, Default::default());
338+
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
325339

326340
let deserialized_kmeans: KMeans<f64> =
327341
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();

src/error/mod.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//! # Custom warnings and errors
2+
use std::error::Error;
3+
use std::fmt;
4+
5+
/// Error to be raised when model does not fits data.
6+
#[derive(Debug)]
7+
pub struct FitFailedError {
8+
details: String
9+
}
10+
11+
/// Error to be raised when model prediction cannot be calculated.
12+
#[derive(Debug)]
13+
pub struct PredictFailedError {
14+
details: String
15+
}
16+
17+
impl FitFailedError {
18+
/// Creates new instance of `FitFailedError`
19+
/// * `msg` - description of the error
20+
pub fn new(msg: &str) -> FitFailedError {
21+
FitFailedError{details: msg.to_string()}
22+
}
23+
}
24+
25+
impl fmt::Display for FitFailedError {
26+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27+
write!(f,"{}",self.details)
28+
}
29+
}
30+
31+
impl Error for FitFailedError {
32+
fn description(&self) -> &str {
33+
&self.details
34+
}
35+
}
36+
37+
impl PredictFailedError {
38+
/// Creates new instance of `PredictFailedError`
39+
/// * `msg` - description of the error
40+
pub fn new(msg: &str) -> PredictFailedError {
41+
PredictFailedError{details: msg.to_string()}
42+
}
43+
}
44+
45+
impl fmt::Display for PredictFailedError {
46+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47+
write!(f,"{}",self.details)
48+
}
49+
}
50+
51+
impl Error for PredictFailedError {
52+
fn description(&self) -> &str {
53+
&self.details
54+
}
55+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,4 @@ pub mod neighbors;
8989
pub(crate) mod optimization;
9090
/// Supervised tree-based learning methods
9191
pub mod tree;
92+
pub mod error;

0 commit comments

Comments
 (0)