43
43
//! &[5.2, 2.7, 3.9, 1.4],
44
44
//! ]);
45
45
//!
46
- //! let kmeans = KMeans::new (&x, 2, Default::default()); // Fit to data, 2 clusters
47
- //! let y_hat = kmeans.predict(&x); // use the same points for prediction
46
+ //! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
47
+ //! let y_hat = kmeans.predict(&x).unwrap(); // Use the same points for prediction
48
48
//! ```
49
49
//!
50
50
//! ## References:
@@ -60,6 +60,7 @@ use std::iter::Sum;
60
60
61
61
use serde:: { Deserialize , Serialize } ;
62
62
63
+ use crate :: error:: { FitFailedError , PredictFailedError } ;
63
64
use crate :: algorithm:: neighbour:: bbd_tree:: BBDTree ;
64
65
use crate :: linalg:: Matrix ;
65
66
use crate :: math:: distance:: euclidian:: * ;
@@ -117,18 +118,17 @@ impl<T: RealNumber + Sum> KMeans<T> {
117
118
/// * `data` - training instances to cluster
118
119
/// * `k` - number of clusters
119
120
/// * `parameters` - cluster parameters
120
- pub fn new < M : Matrix < T > > ( data : & M , k : usize , parameters : KMeansParameters ) -> KMeans < T > {
121
+ pub fn fit < M : Matrix < T > > ( data : & M , k : usize , parameters : KMeansParameters ) -> Result < KMeans < T > , FitFailedError > {
121
122
let bbd = BBDTree :: new ( data) ;
122
123
123
124
if k < 2 {
124
- panic ! ( "Invalid number of clusters: {}" , k) ;
125
+ return Err ( FitFailedError :: new ( & format ! ( "Invalid number of clusters: {}" , k) ) ) ;
125
126
}
126
127
127
128
if parameters. max_iter <= 0 {
128
- panic ! (
129
- "Invalid maximum number of iterations: {}" ,
129
+ return Err ( FitFailedError :: new ( & format ! ( "Invalid maximum number of iterations: {}" ,
130
130
parameters. max_iter
131
- ) ;
131
+ ) ) ) ;
132
132
}
133
133
134
134
let ( n, d) = data. shape ( ) ;
@@ -172,18 +172,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
172
172
}
173
173
}
174
174
175
- KMeans {
175
+ Ok ( KMeans {
176
176
k : k,
177
177
y : y,
178
178
size : size,
179
179
distortion : distortion,
180
180
centroids : centroids,
181
- }
181
+ } )
182
182
}
183
183
184
184
/// Predict clusters for `x`
185
185
/// * `x` - matrix of new data to predict clusters for, of size _KxM_, where _K_ is the number of new samples and _M_ is the number of features.
186
- pub fn predict < M : Matrix < T > > ( & self , x : & M ) -> M :: RowVector {
186
+ pub fn predict < M : Matrix < T > > ( & self , x : & M ) -> Result < M :: RowVector , PredictFailedError > {
187
187
let ( n, _) = x. shape ( ) ;
188
188
let mut result = M :: zeros ( 1 , n) ;
189
189
@@ -201,7 +201,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
201
201
result. set ( 0 , i, T :: from ( best_cluster) . unwrap ( ) ) ;
202
202
}
203
203
204
- result. to_row_vector ( )
204
+ Ok ( result. to_row_vector ( ) )
205
205
}
206
206
207
207
fn kmeans_plus_plus < M : Matrix < T > > ( data : & M , k : usize ) -> Vec < usize > {
@@ -262,6 +262,20 @@ mod tests {
262
262
use super :: * ;
263
263
use crate :: linalg:: naive:: dense_matrix:: DenseMatrix ;
264
264
265
+ #[ test]
266
+ fn invalid_k ( ) {
267
+ let x = DenseMatrix :: from_2d_array ( & [
268
+ & [ 1. , 2. , 3. ] ,
269
+ & [ 4. , 5. , 6. ] ,
270
+ ] ) ;
271
+
272
+ println ! ( "{:?}" , KMeans :: fit( & x, 0 , Default :: default ( ) ) ) ;
273
+
274
+ assert ! ( KMeans :: fit( & x, 0 , Default :: default ( ) ) . is_err( ) ) ;
275
+ assert_eq ! ( "Invalid number of clusters: 1" , KMeans :: fit( & x, 1 , Default :: default ( ) ) . unwrap_err( ) . to_string( ) ) ;
276
+
277
+ }
278
+
265
279
#[ test]
266
280
fn fit_predict_iris ( ) {
267
281
let x = DenseMatrix :: from_2d_array ( & [
@@ -287,9 +301,9 @@ mod tests {
287
301
& [ 5.2 , 2.7 , 3.9 , 1.4 ] ,
288
302
] ) ;
289
303
290
- let kmeans = KMeans :: new ( & x, 2 , Default :: default ( ) ) ;
304
+ let kmeans = KMeans :: fit ( & x, 2 , Default :: default ( ) ) . unwrap ( ) ;
291
305
292
- let y = kmeans. predict ( & x) ;
306
+ let y = kmeans. predict ( & x) . unwrap ( ) ;
293
307
294
308
for i in 0 ..y. len ( ) {
295
309
assert_eq ! ( y[ i] as usize , kmeans. y[ i] ) ;
@@ -321,7 +335,7 @@ mod tests {
321
335
& [ 5.2 , 2.7 , 3.9 , 1.4 ] ,
322
336
] ) ;
323
337
324
- let kmeans = KMeans :: new ( & x, 2 , Default :: default ( ) ) ;
338
+ let kmeans = KMeans :: fit ( & x, 2 , Default :: default ( ) ) . unwrap ( ) ;
325
339
326
340
let deserialized_kmeans: KMeans < f64 > =
327
341
serde_json:: from_str ( & serde_json:: to_string ( & kmeans) . unwrap ( ) ) . unwrap ( ) ;
0 commit comments